From 390e60738fe8d602b50a69e885996132143669f8 Mon Sep 17 00:00:00 2001 From: Paul Brinkmeier Date: Sat, 16 Sep 2023 06:17:08 +0200 Subject: [PATCH] Implement column table stuff --- README.md | 3 +- lib/Database/PostgreSQL/Opium.hs | 92 +++++++++++----------- lib/Database/PostgreSQL/Opium/Error.hs | 3 +- lib/Database/PostgreSQL/Opium/FromField.hs | 54 ++++++------- test/Database/PostgreSQL/OpiumSpec.hs | 35 ++++++-- 5 files changed, 101 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index 00c0595..f04438e 100644 --- a/README.md +++ b/README.md @@ -11,4 +11,5 @@ - [ ] Implement `UTCTime` and zoned time decoding - [ ] Implement JSON decoding - [ ] Implement `ByteString` decoding (`bytea`) - - Can we make let the fromField instance choose whether it wants binary or text? + - Can we make the fromField instance choose whether it wants binary or text? +- [ ] Clean up and document column table stuff diff --git a/lib/Database/PostgreSQL/Opium.hs b/lib/Database/PostgreSQL/Opium.hs index abaab2e..6813f85 100644 --- a/lib/Database/PostgreSQL/Opium.hs +++ b/lib/Database/PostgreSQL/Opium.hs @@ -19,6 +19,7 @@ module Database.PostgreSQL.Opium where import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT) +import Control.Monad.Trans.State (StateT (..), evalStateT, modify) import Data.ByteString (ByteString) import Data.Proxy (Proxy (..)) import Data.Text (Text) @@ -36,7 +37,7 @@ import qualified Data.Text.Encoding as Encoding import qualified Database.PostgreSQL.LibPQ as LibPQ import Database.PostgreSQL.Opium.Error (Error (..)) -import Database.PostgreSQL.Opium.FromField (FieldError (..), FromField (..)) +import Database.PostgreSQL.Opium.FromField (FieldError (..), FromField (..), fromField) execParams :: Connection -> ByteString -> IO (Either Error Result) execParams conn query = do @@ -51,24 +52,24 @@ execParams conn query = do Just message -> pure $ Left $ ErrorInvalidResult status $ Encoding.decodeUtf8 message Nothing -> pure $ Right result -fetch_ :: FromRow a => Connection -> ByteString -> IO (Either Error [a]) +fetch_ :: forall a. FromRow a => Connection -> ByteString -> IO (Either Error [a]) fetch_ conn query = runExceptT $ do result <- ExceptT $ execParams conn query - ExceptT $ fetchResult result + -- TODO: Use unboxed array for columnTable + columnTable <- ExceptT $ getColumnTable @a Proxy result + nRows <- ExceptT $ Right <$> LibPQ.ntuples result + mapM (ExceptT . fromRow result columnTable) [0..nRows - 1] -fetchResult :: FromRow a => Result -> IO (Either Error [a]) -fetchResult result = do - nRows <- LibPQ.ntuples result - runExceptT $ mapM (ExceptT . flip fromRow result) [0..nRows - 1] +type ColumnTable = [Column] class FromRow a where getColumnTable :: Proxy a -> Result -> IO (Either Error [Column]) default getColumnTable :: (Generic a, GetColumnTable' (Rep a)) => Proxy a -> Result -> IO (Either Error [Column]) getColumnTable Proxy = runExceptT . getColumnTable' @(Rep a) Proxy - fromRow :: Row -> Result -> IO (Either Error a) - default fromRow :: (Generic a, FromRow' (Rep a)) => Row -> Result -> IO (Either Error a) - fromRow row result = fmap to <$> fromRow' row result + fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a) + default fromRow :: (Generic a, FromRow' (Rep a)) => Result -> ColumnTable -> Row -> IO (Either Error a) + fromRow result columnTable row = evalStateT (fmap to <$> fromRow' result columnTable row) 0 class GetColumnTable' f where getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [Column] @@ -83,15 +84,17 @@ instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) whe getColumnTable' Proxy result = (++) <$> getColumnTable' @f Proxy result <*> getColumnTable' @g Proxy result -checkColumn :: FromField f => Proxy f -> String -> Result -> ExceptT Error IO [Column] +checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [Column] checkColumn Proxy nameStr result = do - column <- ExceptT $ maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name - -- TODO: Rewrite FromField to check whether oid works for decoding t - _oid <- ExceptT $ Right <$> LibPQ.ftype result column + column <- ExceptT $ maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name + oid <- ExceptT $ Right <$> LibPQ.ftype result column + if validOid @a Proxy oid then pure [column] - where - nameText = Text.pack nameStr - name = Encoding.encodeUtf8 nameText + else + except $ Left $ ErrorInvalidOid nameText oid + where + nameText = Text.pack nameStr + name = Encoding.encodeUtf8 nameText instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy @@ -100,49 +103,48 @@ instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTabl getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy class FromRow' f where - fromRow' :: Row -> Result -> IO (Either Error (f p)) + fromRow' :: Result -> ColumnTable -> Row -> StateT Int IO (Either Error (f p)) instance FromRow' f => FromRow' (M1 D c f) where - fromRow' row result = fmap M1 <$> fromRow' row result + fromRow' result columnTable row = fmap M1 <$> fromRow' result columnTable row instance FromRow' f => FromRow' (M1 C c f) where - fromRow' row result = fmap M1 <$> fromRow' row result + fromRow' result columnTable row = fmap M1 <$> fromRow' result columnTable row instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where - fromRow' row result = do - y <- fromRow' row result - z <- fromRow' row result + fromRow' result columnTable row = do + y <- fromRow' result columnTable row + modify (+1) + z <- fromRow' result columnTable row pure $ (:*:) <$> y <*> z decodeField :: FromField t => Text -> (Row -> Maybe t -> Either Error t') - -> Row -> Result - -> IO (Either Error (M1 S m (Rec0 t') p)) -decodeField nameText g row result = runExceptT $ do - column <- getColumn - oid <- ExceptT $ pure <$> LibPQ.ftype result column - mbField <- getFieldText column - mbValue <- getValue oid mbField - value <- except $ g row mbValue - pure $ M1 $ K1 value - where - name = Encoding.encodeUtf8 nameText + -> ColumnTable + -> Row + -> StateT Int IO (Either Error (M1 S m (Rec0 t') p)) +decodeField nameText g result columnTable row = StateT $ \i -> do + v <- runExceptT $ do + let column = columnTable !! i + oid <- ExceptT $ pure <$> LibPQ.ftype result column + mbField <- getFieldText column + mbValue <- getValue oid mbField + value <- except $ g row mbValue + pure $ M1 $ K1 value - getColumn :: ExceptT Error IO Column - getColumn = ExceptT $ - maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name + pure (v, i) + where + getFieldText :: Column -> ExceptT Error IO (Maybe Text) + getFieldText column = + ExceptT $ Right . fmap Encoding.decodeUtf8 <$> LibPQ.getvalue result row column - getFieldText :: Column -> ExceptT Error IO (Maybe Text) - getFieldText column = - ExceptT $ Right . fmap Encoding.decodeUtf8 <$> LibPQ.getvalue result row column - - getValue :: FromField u => LibPQ.Oid -> Maybe Text -> ExceptT Error IO (Maybe u) - getValue oid = except . maybe - (Right Nothing) - (fmap Just . mapLeft (ErrorDecode row nameText) . fromField oid) + getValue :: FromField u => LibPQ.Oid -> Maybe Text -> ExceptT Error IO (Maybe u) + getValue oid = except . maybe + (Right Nothing) + (fmap Just . mapLeft (ErrorDecode row nameText) . fromField oid) instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where fromRow' = decodeField nameText $ \row -> maybe diff --git a/lib/Database/PostgreSQL/Opium/Error.hs b/lib/Database/PostgreSQL/Opium/Error.hs index e5c7f07..f655031 100644 --- a/lib/Database/PostgreSQL/Opium/Error.hs +++ b/lib/Database/PostgreSQL/Opium/Error.hs @@ -1,7 +1,7 @@ module Database.PostgreSQL.Opium.Error (Error (..)) where import Data.Text (Text) -import Database.PostgreSQL.LibPQ (ExecStatus, Row) +import Database.PostgreSQL.LibPQ (ExecStatus, Oid, Row) import Database.PostgreSQL.Opium.FromField (FieldError) @@ -10,5 +10,6 @@ data Error | ErrorNoResult | ErrorInvalidResult ExecStatus Text | ErrorMissingColumn Text + | ErrorInvalidOid Text Oid | ErrorUnexpectedNull Row Text deriving (Eq, Show) diff --git a/lib/Database/PostgreSQL/Opium/FromField.hs b/lib/Database/PostgreSQL/Opium/FromField.hs index a407ffb..6290d3d 100644 --- a/lib/Database/PostgreSQL/Opium/FromField.hs +++ b/lib/Database/PostgreSQL/Opium/FromField.hs @@ -1,9 +1,11 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeApplications #-} module Database.PostgreSQL.Opium.FromField ( FieldError (..) , FromField (..) + , fromField ) where import Data.Attoparsec.Text @@ -17,6 +19,7 @@ import Data.Attoparsec.Text , takeText ) import Data.Functor (($>)) +import Data.Proxy (Proxy (..)) import Data.Text (Text) import Database.PostgreSQL.LibPQ (Oid) import GHC.Float (double2Float) @@ -29,8 +32,6 @@ p \/ q = \x -> p x || q x data FieldError = FieldErrorUnexpectedNull - -- TODO: Move this to the normal Error - | FieldErrorInvalidOid Oid | FieldErrorInvalidField Oid Text String deriving (Eq, Show) @@ -38,36 +39,28 @@ mapLeft :: (b -> c) -> Either b a -> Either c a mapLeft f (Left l) = Left $ f l mapLeft _ (Right r) = Right r -fromParser - :: (Oid -> Bool) - -> Parser a - -> Oid - -> Text - -> Either FieldError a -fromParser validOid parser oid field - | validOid oid = - mapLeft (FieldErrorInvalidField oid field) $ parseOnly parser field - | otherwise = - Left $ FieldErrorInvalidOid oid +fromField :: FromField a => Oid -> Text -> Either FieldError a +fromField oid field = + mapLeft (FieldErrorInvalidField oid field) $ parseOnly parseField field class FromField a where - fromField :: Oid -> Text -> Either FieldError a + validOid :: Proxy a -> Oid -> Bool + parseField :: Parser a instance FromField Text where - fromField = fromParser - (Oid.text \/ Oid.character \/ Oid.characterVarying) - takeText + validOid _ = Oid.text \/ Oid.character \/ Oid.characterVarying + parseField = takeText instance FromField String where - fromField oid text = Text.unpack <$> fromField oid text + validOid _ = validOid @Text Proxy + parseField = Text.unpack <$> parseField instance FromField Int where - fromField = fromParser - (Oid.smallint \/ Oid.integer \/ Oid.bigint) - (signed decimal) + validOid _ = Oid.smallint \/ Oid.integer \/ Oid.bigint + parseField = signed decimal -floatParser :: Parser Double -floatParser = choice +doubleParser :: Parser Double +doubleParser = choice [ string "NaN" $> nan , signed (string "Infinity" $> infinity) , double @@ -77,14 +70,12 @@ floatParser = choice infinity = 1 / 0 instance FromField Float where - fromField = fromParser - Oid.real - (fmap double2Float floatParser) + validOid _ = Oid.real + parseField = fmap double2Float doubleParser instance FromField Double where - fromField = fromParser - (Oid.real \/ Oid.doublePrecision) - floatParser + validOid _ = Oid.real \/ Oid.doublePrecision + parseField = doubleParser boolParser :: Parser Bool boolParser = choice @@ -94,6 +85,5 @@ boolParser = choice -- | See https://www.postgresql.org/docs/current/datatype-boolean.html. instance FromField Bool where - fromField = fromParser - Oid.boolean - boolParser + validOid _ = Oid.boolean + parseField = boolParser diff --git a/test/Database/PostgreSQL/OpiumSpec.hs b/test/Database/PostgreSQL/OpiumSpec.hs index 86c8a12..a60e46c 100644 --- a/test/Database/PostgreSQL/OpiumSpec.hs +++ b/test/Database/PostgreSQL/OpiumSpec.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -26,6 +27,16 @@ newtype MaybeTest = MaybeTest instance Opium.FromRow MaybeTest where +data ManyFields = ManyFields + { a :: Text + , b :: Int + , c :: Double + , d :: String + , e :: Bool + } deriving (Eq, Generic, Show) + +instance Opium.FromRow ManyFields where + isLeft :: Either a b -> Bool isLeft (Left _) = True isLeft _ = False @@ -43,37 +54,47 @@ spec = do columnTable0 <- Opium.getColumnTable @Person Proxy result0 columnTable0 `shouldBe` Right [1, 0] - Just result1 <- LibPQ.execParams conn "SELECT 0 a, 1 b, 2 c, age, 4 d, name FROM person" [] LibPQ.Text + Just result1 <- LibPQ.execParams conn "SELECT 0 AS a, 1 AS b, 2 AS c, age, 4 AS d, name FROM person" [] LibPQ.Text columnTable1 <- Opium.getColumnTable @Person Proxy result1 columnTable1 `shouldBe` Right [5, 3] it "Fails for missing columns" $ \conn -> do - Just result <- LibPQ.execParams conn "SELECT 0 a FROM person" [] LibPQ.Text + Just result <- LibPQ.execParams conn "SELECT 0 AS a FROM person" [] LibPQ.Text columnTable <- Opium.getColumnTable @Person Proxy result columnTable `shouldBe` Left (Opium.ErrorMissingColumn "name") describe "fromRow" $ do it "Decodes rows in a Result" $ \conn -> do Just result <- LibPQ.execParams conn "SELECT * FROM person" [] LibPQ.Text + Right columnTable <- Opium.getColumnTable @Person Proxy result - row0 <- Opium.fromRow @Person (LibPQ.Row 0) result + row0 <- Opium.fromRow @Person result columnTable (LibPQ.Row 0) row0 `shouldBe` Right (Person "paul" 25) - row1 <- Opium.fromRow @Person (LibPQ.Row 1) result + row1 <- Opium.fromRow @Person result columnTable (LibPQ.Row 1) row1 `shouldBe` Right (Person "albus" 103) it "Decodes NULL into Nothing for Maybes" $ \conn -> do Just result <- LibPQ.execParams conn "SELECT NULL AS a" [] LibPQ.Text + Right columnTable <- Opium.getColumnTable @MaybeTest Proxy result - row <- Opium.fromRow (LibPQ.Row 0) result + row <- Opium.fromRow result columnTable (LibPQ.Row 0) row `shouldBe` Right (MaybeTest Nothing) it "Decodes values into Just for Maybes" $ \conn -> do Just result <- LibPQ.execParams conn "SELECT 'abc' AS a" [] LibPQ.Text + Right columnTable <- Opium.getColumnTable @MaybeTest Proxy result - row <- Opium.fromRow (LibPQ.Row 0) result + row <- Opium.fromRow result columnTable (LibPQ.Row 0) row `shouldBe` Right (MaybeTest $ Just "abc") + it "Works for many fields" $ \conn -> do + Just result <- LibPQ.execParams conn "SELECT 'abc' AS a, 42 AS b, 1.0::double precision AS c, 'test' AS d, true AS e" [] LibPQ.Text + Right columnTable <- Opium.getColumnTable @ManyFields Proxy result + + row <- Opium.fromRow result columnTable (LibPQ.Row 0) + row `shouldBe` Right (ManyFields "abc" 42 1.0 "test" True) + describe "fetch_" $ do it "Retrieves a list of rows" $ \conn -> do rows <- Opium.fetch_ conn "SELECT * FROM person" @@ -89,4 +110,4 @@ spec = do it "Fails for the wrong column type" $ \conn -> do rows <- Opium.fetch_ @Person conn "SELECT 'quby' AS name, 'indeterminate' AS age" - rows `shouldBe` Left (Opium.ErrorDecode (LibPQ.Row 0) "age" $ Opium.FieldErrorInvalidOid $ LibPQ.Oid 25) + rows `shouldBe` Left (Opium.ErrorInvalidOid "age" $ LibPQ.Oid 25)