diff --git a/lib/Database/PostgreSQL/Opium.hs b/lib/Database/PostgreSQL/Opium.hs index 1cb2e2a..825f4e4 100644 --- a/lib/Database/PostgreSQL/Opium.hs +++ b/lib/Database/PostgreSQL/Opium.hs @@ -10,20 +10,31 @@ {-# LANGUAGE TypeOperators #-} module Database.PostgreSQL.Opium - ( ColumnTable + -- * Queries + -- + -- | TODO: Add @newtype Query = Query Text@ with @IsString@ instance to make constructing query strings at run time harder. + ( fetch + , fetch_ + , execute + , execute_ + -- * Classes to Implement + , FromRow (..) + , FromField (..) + -- * Utility Stuff , Error (..) , ErrorPosition (..) - , FromField (..) - , FromRow (..) , RawField (..) - , fetch - , fetch_ + -- * Exported for unit tests + -- + -- | TODO: Don't export this from top-level module. + , ColumnTable , toListColumnTable ) where import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT) +import Data.Bifunctor (first) import Data.ByteString (ByteString) import Data.IORef (IORef, modifyIORef', newIORef, readIORef) import Data.Proxy (Proxy (..)) @@ -83,6 +94,19 @@ fetch conn query params = runExceptT $ do fetch_ :: forall a. FromRow a => Connection -> Text -> IO (Either Error [a]) fetch_ conn query = fetch conn query () +execute + :: forall a. ToParamList a + => Connection + -> Text + -> a + -> IO (Either Error ()) +execute conn query params = runExceptT $ do + _ <- execParams conn query params + pure () + +execute_ :: Connection -> Text -> IO (Either Error ()) +execute_ conn query = execute conn query () + newtype ColumnTable = ColumnTable (Vector (Column, Oid)) deriving (Eq, Show) @@ -137,7 +161,11 @@ instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTab instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy -data FromRowCtx = FromRowCtx Result ColumnTable (IORef Int) +-- | State kept for a call to 'fromRow'. +data FromRowCtx = FromRowCtx + Result -- ^ Obtained from 'LibPQ.execParams'. + ColumnTable -- ^ 'Vector' of expected columns indices and OIDs. + (IORef Int) -- ^ Index into 'ColumnTable', incremented after each column. TODO: Make this nicer. class FromRow' f where fromRow' :: FromRowCtx -> Row -> ExceptT Error IO (f p) @@ -166,20 +194,16 @@ decodeField nameText g (FromRowCtx result columnTable iRef) row = do liftIO $ modifyIORef' iRef (+1) let (column, oid) = columnTable `indexColumnTable` i mbField <- liftIO $ LibPQ.getvalue result row column - mbValue <- except $ getValue oid mbField + mbValue <- except $ fromFieldIfPresent oid mbField value <- except $ g row mbValue pure $ M1 $ K1 value where - getValue :: FromField u => LibPQ.Oid -> Maybe ByteString -> Either Error (Maybe u) - getValue oid = maybe (Right Nothing) $ \field -> - mapLeft + fromFieldIfPresent :: FromField u => LibPQ.Oid -> Maybe ByteString -> Either Error (Maybe u) + fromFieldIfPresent oid = maybe (Right Nothing) $ \field -> + first (ErrorInvalidField (ErrorPosition row nameText) oid field) (Just <$> fromField field) -mapLeft :: (b -> c) -> Either b a -> Either c a -mapLeft f (Left l) = Left $ f l -mapLeft _ (Right r) = Right r - instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where fromRow' = decodeField nameText $ \row -> maybe (Left $ ErrorUnexpectedNull $ ErrorPosition row nameText) Right