Clean up column table code

This commit is contained in:
Paul Brinkmeier 2023-09-16 18:39:45 +02:00
parent 390e60738f
commit 335b6188d1
4 changed files with 91 additions and 89 deletions

View File

@ -11,21 +11,23 @@
module Database.PostgreSQL.Opium module Database.PostgreSQL.Opium
( Error (..) ( Error (..)
, FieldError (..) , ErrorPosition (..)
, FromField (..) , FromField (..)
, FromRow (..) , FromRow (..)
, fetch_ , fetch_
) )
where where
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT) import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Control.Monad.Trans.State (StateT (..), evalStateT, modify)
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import Data.Proxy (Proxy (..)) import Data.Proxy (Proxy (..))
import Data.Text (Text) import Data.Text (Text)
import Database.PostgreSQL.LibPQ import Database.PostgreSQL.LibPQ
( Column ( Column
, Connection , Connection
, Oid
, Result , Result
, Row , Row
) )
@ -36,43 +38,48 @@ import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding import qualified Data.Text.Encoding as Encoding
import qualified Database.PostgreSQL.LibPQ as LibPQ import qualified Database.PostgreSQL.LibPQ as LibPQ
import Database.PostgreSQL.Opium.Error (Error (..)) import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..))
import Database.PostgreSQL.Opium.FromField (FieldError (..), FromField (..), fromField) import Database.PostgreSQL.Opium.FromField (FromField (..), fromField)
execParams :: Connection -> ByteString -> IO (Either Error Result) execParams :: Connection -> ByteString -> ExceptT Error IO Result
execParams conn query = do execParams conn query = do
LibPQ.execParams conn query [] LibPQ.Text >>= \case liftIO (LibPQ.execParams conn query [] LibPQ.Text) >>= \case
Nothing -> Nothing ->
pure $ Left ErrorNoResult except $ Left ErrorNoResult
Just result -> do Just result -> do
status <- LibPQ.resultStatus result status <- liftIO $ LibPQ.resultStatus result
mbMessage <- LibPQ.resultErrorMessage result mbMessage <- liftIO $ LibPQ.resultErrorMessage result
case mbMessage of case mbMessage of
Just "" -> pure $ Right result Just "" -> pure result
Just message -> pure $ Left $ ErrorInvalidResult status $ Encoding.decodeUtf8 message Nothing -> pure result
Nothing -> pure $ Right result Just message -> except $ Left $ ErrorInvalidResult status $ Encoding.decodeUtf8 message
fetch_ :: forall a. FromRow a => Connection -> ByteString -> IO (Either Error [a]) fetch_ :: forall a. FromRow a => Connection -> ByteString -> IO (Either Error [a])
fetch_ conn query = runExceptT $ do fetch_ conn query = runExceptT $ do
result <- ExceptT $ execParams conn query result <- execParams conn query
-- TODO: Use unboxed array for columnTable -- TODO: Use unboxed array for columnTable
columnTable <- ExceptT $ getColumnTable @a Proxy result columnTable <- ExceptT $ getColumnTable @a Proxy result
nRows <- ExceptT $ Right <$> LibPQ.ntuples result nRows <- liftIO $ LibPQ.ntuples result
mapM (ExceptT . fromRow result columnTable) [0..nRows - 1] mapM (ExceptT . fromRow result columnTable) [0..nRows - 1]
type ColumnTable = [Column] type ColumnTable = [(Column, Oid)]
indexColumnTable :: ColumnTable -> Int -> (Column, Oid)
indexColumnTable = (!!)
class FromRow a where class FromRow a where
getColumnTable :: Proxy a -> Result -> IO (Either Error [Column]) getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
default getColumnTable :: (Generic a, GetColumnTable' (Rep a)) => Proxy a -> Result -> IO (Either Error [Column]) default getColumnTable :: (Generic a, GetColumnTable' (Rep a)) => Proxy a -> Result -> IO (Either Error ColumnTable)
getColumnTable Proxy = runExceptT . getColumnTable' @(Rep a) Proxy getColumnTable Proxy = runExceptT . getColumnTable' @(Rep a) Proxy
fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a) fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
default fromRow :: (Generic a, FromRow' (Rep a)) => Result -> ColumnTable -> Row -> IO (Either Error a) default fromRow :: (Generic a, FromRow' (Rep a)) => Result -> ColumnTable -> Row -> IO (Either Error a)
fromRow result columnTable row = evalStateT (fmap to <$> fromRow' result columnTable row) 0 fromRow result columnTable row = do
iRef <- newIORef 0
runExceptT $ to <$> fromRow' (FromRowCtx result columnTable iRef) row
class GetColumnTable' f where class GetColumnTable' f where
getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [Column] getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO ColumnTable
instance GetColumnTable' f => GetColumnTable' (M1 D c f) where instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
getColumnTable' Proxy = getColumnTable' @f Proxy getColumnTable' Proxy = getColumnTable' @f Proxy
@ -84,12 +91,12 @@ instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) whe
getColumnTable' Proxy result = getColumnTable' Proxy result =
(++) <$> getColumnTable' @f Proxy result <*> getColumnTable' @g Proxy result (++) <$> getColumnTable' @f Proxy result <*> getColumnTable' @g Proxy result
checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [Column] checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO ColumnTable
checkColumn Proxy nameStr result = do checkColumn Proxy nameStr result = do
column <- ExceptT $ maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name column <- ExceptT $ maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name
oid <- ExceptT $ Right <$> LibPQ.ftype result column oid <- liftIO $ LibPQ.ftype result column
if validOid @a Proxy oid then if validOid @a Proxy oid then
pure [column] pure [(column, oid)]
else else
except $ Left $ ErrorInvalidOid nameText oid except $ Left $ ErrorInvalidOid nameText oid
where where
@ -102,62 +109,60 @@ instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTab
instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy
data FromRowCtx = FromRowCtx Result ColumnTable (IORef Int)
class FromRow' f where class FromRow' f where
fromRow' :: Result -> ColumnTable -> Row -> StateT Int IO (Either Error (f p)) fromRow' :: FromRowCtx -> Row -> ExceptT Error IO (f p)
instance FromRow' f => FromRow' (M1 D c f) where instance FromRow' f => FromRow' (M1 D c f) where
fromRow' result columnTable row = fmap M1 <$> fromRow' result columnTable row fromRow' ctx row = M1 <$> fromRow' ctx row
instance FromRow' f => FromRow' (M1 C c f) where instance FromRow' f => FromRow' (M1 C c f) where
fromRow' result columnTable row = fmap M1 <$> fromRow' result columnTable row fromRow' ctx row = M1 <$> fromRow' ctx row
instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where
fromRow' result columnTable row = do fromRow' ctx@(FromRowCtx _ _ iRef) row = do
y <- fromRow' result columnTable row y <- fromRow' ctx row
modify (+1) liftIO $ modifyIORef' iRef (+1)
z <- fromRow' result columnTable row z <- fromRow' ctx row
pure $ (:*:) <$> y <*> z pure $ y :*: z
decodeField decodeField
:: FromField t :: FromField t
=> Text => Text
-> (Row -> Maybe t -> Either Error t') -> (Row -> Maybe t -> Either Error t')
-> Result -> FromRowCtx
-> ColumnTable
-> Row -> Row
-> StateT Int IO (Either Error (M1 S m (Rec0 t') p)) -> ExceptT Error IO (M1 S m (Rec0 t') p)
decodeField nameText g result columnTable row = StateT $ \i -> do decodeField nameText g (FromRowCtx result columnTable iRef) row = do
v <- runExceptT $ do i <- liftIO $ readIORef iRef
let column = columnTable !! i let (column, oid) = columnTable `indexColumnTable` i
oid <- ExceptT $ pure <$> LibPQ.ftype result column mbField <- liftIO $ getFieldText column
mbField <- getFieldText column mbValue <- except $ getValue oid mbField
mbValue <- getValue oid mbField value <- except $ g row mbValue
value <- except $ g row mbValue pure $ M1 $ K1 value
pure $ M1 $ K1 value
pure (v, i)
where where
getFieldText :: Column -> ExceptT Error IO (Maybe Text) getFieldText :: Column -> IO (Maybe Text)
getFieldText column = getFieldText column =
ExceptT $ Right . fmap Encoding.decodeUtf8 <$> LibPQ.getvalue result row column fmap Encoding.decodeUtf8 <$> LibPQ.getvalue result row column
getValue :: FromField u => LibPQ.Oid -> Maybe Text -> ExceptT Error IO (Maybe u) getValue :: FromField u => LibPQ.Oid -> Maybe Text -> Either Error (Maybe u)
getValue oid = except . maybe getValue oid = maybe (Right Nothing) $ \fieldText ->
(Right Nothing) mapLeft
(fmap Just . mapLeft (ErrorDecode row nameText) . fromField oid) (ErrorInvalidField (ErrorPosition row nameText) oid fieldText)
(Just <$> fromField fieldText)
instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
fromRow' = decodeField nameText $ \row -> maybe
(Left $ ErrorUnexpectedNull row nameText)
Right
where
nameText = Text.pack $ symbolVal @nameSym Proxy
instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
fromRow' = decodeField nameText $ const Right
where
nameText = Text.pack $ symbolVal @nameSym Proxy
mapLeft :: (b -> c) -> Either b a -> Either c a mapLeft :: (b -> c) -> Either b a -> Either c a
mapLeft f (Left l) = Left $ f l mapLeft f (Left l) = Left $ f l
mapLeft _ (Right r) = Right r mapLeft _ (Right r) = Right r
instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
fromRow' = decodeField nameText $ \row ->
maybe (Left $ ErrorUnexpectedNull $ ErrorPosition row nameText) Right
where
nameText = Text.pack $ symbolVal @nameSym Proxy
instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
fromRow' = decodeField nameText $ const pure
where
nameText = Text.pack $ symbolVal @nameSym Proxy

View File

@ -1,15 +1,21 @@
module Database.PostgreSQL.Opium.Error (Error (..)) where module Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..)) where
import Control.Exception (Exception)
import Data.Text (Text) import Data.Text (Text)
import Database.PostgreSQL.LibPQ (ExecStatus, Oid, Row) import Database.PostgreSQL.LibPQ (ExecStatus, Oid, Row)
import Database.PostgreSQL.Opium.FromField (FieldError) data ErrorPosition = ErrorPosition
{ errorPositionRow :: Row
, errorPositionColumn :: Text
} deriving (Eq, Show)
data Error data Error
= ErrorDecode Row Text FieldError = ErrorNoResult
| ErrorNoResult
| ErrorInvalidResult ExecStatus Text | ErrorInvalidResult ExecStatus Text
| ErrorMissingColumn Text | ErrorMissingColumn Text
| ErrorInvalidOid Text Oid | ErrorInvalidOid Text Oid
| ErrorUnexpectedNull Row Text | ErrorUnexpectedNull ErrorPosition
| ErrorInvalidField ErrorPosition Oid Text String
deriving (Eq, Show) deriving (Eq, Show)
instance Exception Error where

View File

@ -3,8 +3,7 @@
{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeApplications #-}
module Database.PostgreSQL.Opium.FromField module Database.PostgreSQL.Opium.FromField
( FieldError (..) ( FromField (..)
, FromField (..)
, fromField , fromField
) where ) where
@ -25,23 +24,15 @@ import Database.PostgreSQL.LibPQ (Oid)
import GHC.Float (double2Float) import GHC.Float (double2Float)
import qualified Data.Text as Text import qualified Data.Text as Text
import qualified Database.PostgreSQL.Opium.Oid as Oid import qualified Database.PostgreSQL.Opium.Oid as Oid
(\/) :: (a -> Bool) -> (a -> Bool) -> a -> Bool (\/) :: (a -> Bool) -> (a -> Bool) -> a -> Bool
p \/ q = \x -> p x || q x p \/ q = \x -> p x || q x
data FieldError fromField :: FromField a => Text -> Either String a
= FieldErrorUnexpectedNull fromField =
| FieldErrorInvalidField Oid Text String parseOnly parseField
deriving (Eq, Show)
mapLeft :: (b -> c) -> Either b a -> Either c a
mapLeft f (Left l) = Left $ f l
mapLeft _ (Right r) = Right r
fromField :: FromField a => Oid -> Text -> Either FieldError a
fromField oid field =
mapLeft (FieldErrorInvalidField oid field) $ parseOnly parseField field
class FromField a where class FromField a where
validOid :: Proxy a -> Oid -> Bool validOid :: Proxy a -> Oid -> Bool

View File

@ -47,16 +47,16 @@ spec = do
it "Gets the column table for a result" $ \conn -> do it "Gets the column table for a result" $ \conn -> do
Just result <- LibPQ.execParams conn "SELECT name, age FROM person" [] LibPQ.Text Just result <- LibPQ.execParams conn "SELECT name, age FROM person" [] LibPQ.Text
columnTable <- Opium.getColumnTable @Person Proxy result columnTable <- Opium.getColumnTable @Person Proxy result
columnTable `shouldBe` Right [0, 1] fmap (map fst) columnTable `shouldBe` Right [0, 1]
it "Gets the numbers right for funky configurations" $ \conn -> do it "Gets the numbers right for funky configurations" $ \conn -> do
Just result0 <- LibPQ.execParams conn "SELECT age, name FROM person" [] LibPQ.Text Just result0 <- LibPQ.execParams conn "SELECT age, name FROM person" [] LibPQ.Text
columnTable0 <- Opium.getColumnTable @Person Proxy result0 columnTable0 <- Opium.getColumnTable @Person Proxy result0
columnTable0 `shouldBe` Right [1, 0] fmap (map fst) columnTable0 `shouldBe` Right [1, 0]
Just result1 <- LibPQ.execParams conn "SELECT 0 AS a, 1 AS b, 2 AS c, age, 4 AS d, name FROM person" [] LibPQ.Text Just result1 <- LibPQ.execParams conn "SELECT 0 AS a, 1 AS b, 2 AS c, age, 4 AS d, name FROM person" [] LibPQ.Text
columnTable1 <- Opium.getColumnTable @Person Proxy result1 columnTable1 <- Opium.getColumnTable @Person Proxy result1
columnTable1 `shouldBe` Right [5, 3] fmap (map fst) columnTable1 `shouldBe` Right [5, 3]
it "Fails for missing columns" $ \conn -> do it "Fails for missing columns" $ \conn -> do
Just result <- LibPQ.execParams conn "SELECT 0 AS a FROM person" [] LibPQ.Text Just result <- LibPQ.execParams conn "SELECT 0 AS a FROM person" [] LibPQ.Text
@ -68,31 +68,31 @@ spec = do
Just result <- LibPQ.execParams conn "SELECT * FROM person" [] LibPQ.Text Just result <- LibPQ.execParams conn "SELECT * FROM person" [] LibPQ.Text
Right columnTable <- Opium.getColumnTable @Person Proxy result Right columnTable <- Opium.getColumnTable @Person Proxy result
row0 <- Opium.fromRow @Person result columnTable (LibPQ.Row 0) row0 <- Opium.fromRow @Person result columnTable 0
row0 `shouldBe` Right (Person "paul" 25) row0 `shouldBe` Right (Person "paul" 25)
row1 <- Opium.fromRow @Person result columnTable (LibPQ.Row 1) row1 <- Opium.fromRow @Person result columnTable 1
row1 `shouldBe` Right (Person "albus" 103) row1 `shouldBe` Right (Person "albus" 103)
it "Decodes NULL into Nothing for Maybes" $ \conn -> do it "Decodes NULL into Nothing for Maybes" $ \conn -> do
Just result <- LibPQ.execParams conn "SELECT NULL AS a" [] LibPQ.Text Just result <- LibPQ.execParams conn "SELECT NULL AS a" [] LibPQ.Text
Right columnTable <- Opium.getColumnTable @MaybeTest Proxy result Right columnTable <- Opium.getColumnTable @MaybeTest Proxy result
row <- Opium.fromRow result columnTable (LibPQ.Row 0) row <- Opium.fromRow result columnTable 0
row `shouldBe` Right (MaybeTest Nothing) row `shouldBe` Right (MaybeTest Nothing)
it "Decodes values into Just for Maybes" $ \conn -> do it "Decodes values into Just for Maybes" $ \conn -> do
Just result <- LibPQ.execParams conn "SELECT 'abc' AS a" [] LibPQ.Text Just result <- LibPQ.execParams conn "SELECT 'abc' AS a" [] LibPQ.Text
Right columnTable <- Opium.getColumnTable @MaybeTest Proxy result Right columnTable <- Opium.getColumnTable @MaybeTest Proxy result
row <- Opium.fromRow result columnTable (LibPQ.Row 0) row <- Opium.fromRow result columnTable 0
row `shouldBe` Right (MaybeTest $ Just "abc") row `shouldBe` Right (MaybeTest $ Just "abc")
it "Works for many fields" $ \conn -> do it "Works for many fields" $ \conn -> do
Just result <- LibPQ.execParams conn "SELECT 'abc' AS a, 42 AS b, 1.0::double precision AS c, 'test' AS d, true AS e" [] LibPQ.Text Just result <- LibPQ.execParams conn "SELECT 'abc' AS a, 42 AS b, 1.0::double precision AS c, 'test' AS d, true AS e" [] LibPQ.Text
Right columnTable <- Opium.getColumnTable @ManyFields Proxy result Right columnTable <- Opium.getColumnTable @ManyFields Proxy result
row <- Opium.fromRow result columnTable (LibPQ.Row 0) row <- Opium.fromRow result columnTable 0
row `shouldBe` Right (ManyFields "abc" 42 1.0 "test" True) row `shouldBe` Right (ManyFields "abc" 42 1.0 "test" True)
describe "fetch_" $ do describe "fetch_" $ do
@ -106,7 +106,7 @@ spec = do
it "Fails for unexpected NULLs" $ \conn -> do it "Fails for unexpected NULLs" $ \conn -> do
rows <- Opium.fetch_ @Person conn "SELECT NULL AS name, 0 AS age" rows <- Opium.fetch_ @Person conn "SELECT NULL AS name, 0 AS age"
rows `shouldBe` Left (Opium.ErrorUnexpectedNull (LibPQ.Row 0) "name") rows `shouldBe` Left (Opium.ErrorUnexpectedNull (Opium.ErrorPosition 0 "name"))
it "Fails for the wrong column type" $ \conn -> do it "Fails for the wrong column type" $ \conn -> do
rows <- Opium.fetch_ @Person conn "SELECT 'quby' AS name, 'indeterminate' AS age" rows <- Opium.fetch_ @Person conn "SELECT 'quby' AS name, 'indeterminate' AS age"