Add some comments, execute and execute_

This commit is contained in:
Paul Brinkmeier 2024-06-10 10:26:40 +02:00
parent 20d150d12c
commit 301e20e7e8

View File

@ -10,20 +10,31 @@
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
module Database.PostgreSQL.Opium module Database.PostgreSQL.Opium
( ColumnTable -- * Queries
--
-- | TODO: Add @newtype Query = Query Text@ with @IsString@ instance to make constructing query strings at run time harder.
( fetch
, fetch_
, execute
, execute_
-- * Classes to Implement
, FromRow (..)
, FromField (..)
-- * Utility Stuff
, Error (..) , Error (..)
, ErrorPosition (..) , ErrorPosition (..)
, FromField (..)
, FromRow (..)
, RawField (..) , RawField (..)
, fetch -- * Exported for unit tests
, fetch_ --
-- | TODO: Don't export this from top-level module.
, ColumnTable
, toListColumnTable , toListColumnTable
) )
where where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT) import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.Bifunctor (first)
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef) import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import Data.Proxy (Proxy (..)) import Data.Proxy (Proxy (..))
@ -83,6 +94,19 @@ fetch conn query params = runExceptT $ do
fetch_ :: forall a. FromRow a => Connection -> Text -> IO (Either Error [a]) fetch_ :: forall a. FromRow a => Connection -> Text -> IO (Either Error [a])
fetch_ conn query = fetch conn query () fetch_ conn query = fetch conn query ()
execute
:: forall a. ToParamList a
=> Connection
-> Text
-> a
-> IO (Either Error ())
execute conn query params = runExceptT $ do
_ <- execParams conn query params
pure ()
execute_ :: Connection -> Text -> IO (Either Error ())
execute_ conn query = execute conn query ()
newtype ColumnTable = ColumnTable (Vector (Column, Oid)) newtype ColumnTable = ColumnTable (Vector (Column, Oid))
deriving (Eq, Show) deriving (Eq, Show)
@ -137,7 +161,11 @@ instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTab
instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy
data FromRowCtx = FromRowCtx Result ColumnTable (IORef Int) -- | State kept for a call to 'fromRow'.
data FromRowCtx = FromRowCtx
Result -- ^ Obtained from 'LibPQ.execParams'.
ColumnTable -- ^ 'Vector' of expected columns indices and OIDs.
(IORef Int) -- ^ Index into 'ColumnTable', incremented after each column. TODO: Make this nicer.
class FromRow' f where class FromRow' f where
fromRow' :: FromRowCtx -> Row -> ExceptT Error IO (f p) fromRow' :: FromRowCtx -> Row -> ExceptT Error IO (f p)
@ -166,20 +194,16 @@ decodeField nameText g (FromRowCtx result columnTable iRef) row = do
liftIO $ modifyIORef' iRef (+1) liftIO $ modifyIORef' iRef (+1)
let (column, oid) = columnTable `indexColumnTable` i let (column, oid) = columnTable `indexColumnTable` i
mbField <- liftIO $ LibPQ.getvalue result row column mbField <- liftIO $ LibPQ.getvalue result row column
mbValue <- except $ getValue oid mbField mbValue <- except $ fromFieldIfPresent oid mbField
value <- except $ g row mbValue value <- except $ g row mbValue
pure $ M1 $ K1 value pure $ M1 $ K1 value
where where
getValue :: FromField u => LibPQ.Oid -> Maybe ByteString -> Either Error (Maybe u) fromFieldIfPresent :: FromField u => LibPQ.Oid -> Maybe ByteString -> Either Error (Maybe u)
getValue oid = maybe (Right Nothing) $ \field -> fromFieldIfPresent oid = maybe (Right Nothing) $ \field ->
mapLeft first
(ErrorInvalidField (ErrorPosition row nameText) oid field) (ErrorInvalidField (ErrorPosition row nameText) oid field)
(Just <$> fromField field) (Just <$> fromField field)
mapLeft :: (b -> c) -> Either b a -> Either c a
mapLeft f (Left l) = Left $ f l
mapLeft _ (Right r) = Right r
instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
fromRow' = decodeField nameText $ \row -> fromRow' = decodeField nameText $ \row ->
maybe (Left $ ErrorUnexpectedNull $ ErrorPosition row nameText) Right maybe (Left $ ErrorUnexpectedNull $ ErrorPosition row nameText) Right