From fea11b5f24cc738cadad235c14013c30fdab5769 Mon Sep 17 00:00:00 2001 From: Paul Brinkmeier Date: Mon, 10 Jun 2024 21:08:29 +0200 Subject: [PATCH] Remove ugly IORef from fromRow and replace it by ugly compile time calculation :) --- README.md | 5 + lib/Database/PostgreSQL/Opium.hs | 178 ++++------------------- lib/Database/PostgreSQL/Opium/FromRow.hs | 162 +++++++++++++++++++++ opium.cabal | 5 +- test/Database/PostgreSQL/OpiumSpec.hs | 3 +- 5 files changed, 200 insertions(+), 153 deletions(-) create mode 100644 lib/Database/PostgreSQL/Opium/FromRow.hs diff --git a/README.md b/README.md index 6dccfc4..1c114bd 100644 --- a/README.md +++ b/README.md @@ -74,4 +74,9 @@ getScoreByAge conn = do - [ ] Implement JSON decoding - [ ] Implement (anonymous) composite types - [ ] Catch [UnicodeException](https://hackage.haskell.org/package/text-2.1/docs/Data-Text-Encoding-Error.html#t:UnicodeException) when decoding text + - This might not be necessary if Postgres guarantees us that having a textual OID on a field means that the field is encoded correctly. - [ ] Implement array decoding +- [ ] Better docs and structure for `FromRow` module +- [ ] Lexer for PostgreSQL that replaces $name by $1, $2, etc. +- [ ] Tutorial +- [ ] Rationale diff --git a/lib/Database/PostgreSQL/Opium.hs b/lib/Database/PostgreSQL/Opium.hs index 825f4e4..c2273c1 100644 --- a/lib/Database/PostgreSQL/Opium.hs +++ b/lib/Database/PostgreSQL/Opium.hs @@ -1,5 +1,4 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} @@ -7,7 +6,6 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} module Database.PostgreSQL.Opium -- * Queries @@ -20,74 +18,43 @@ module Database.PostgreSQL.Opium -- * Classes to Implement , FromRow (..) , FromField (..) + , ToParamList (..) + , ToField (..) -- * Utility Stuff , Error (..) , ErrorPosition (..) , RawField (..) - -- * Exported for unit tests - -- - -- | TODO: Don't export this from top-level module. - , ColumnTable - , toListColumnTable ) where import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT) -import Data.Bifunctor (first) -import Data.ByteString (ByteString) -import Data.IORef (IORef, modifyIORef', newIORef, readIORef) import Data.Proxy (Proxy (..)) import Data.Text (Text) -import Data.Vector (Vector) import Database.PostgreSQL.LibPQ - ( Column - , Connection - , Oid + ( Connection , Result - , Row ) -import GHC.Generics (C, D, Generic, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..)) -import GHC.TypeLits (KnownSymbol, symbolVal) -import qualified Data.Text as Text import qualified Data.Text.Encoding as Encoding -import qualified Data.Vector as Vector import qualified Database.PostgreSQL.LibPQ as LibPQ import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..)) -import Database.PostgreSQL.Opium.FromField (FromField (..), fromField, RawField (..)) +import Database.PostgreSQL.Opium.FromField (FromField (..), RawField (..)) +import Database.PostgreSQL.Opium.FromRow (FromRow (..)) +import Database.PostgreSQL.Opium.ToField (ToField (..)) import Database.PostgreSQL.Opium.ToParamList (ToParamList (..)) -execParams - :: ToParamList a - => Connection - -> Text - -> a - -> ExceptT Error IO Result -execParams conn query params = do - let queryBytes = Encoding.encodeUtf8 query - liftIO (LibPQ.execParams conn queryBytes (toParamList params) LibPQ.Binary) >>= \case - Nothing -> - except $ Left ErrorNoResult - Just result -> do - status <- liftIO $ LibPQ.resultStatus result - mbMessage <- liftIO $ LibPQ.resultErrorMessage result - case mbMessage of - Just "" -> pure result - Nothing -> pure result - Just message -> except $ Left $ ErrorInvalidResult status $ Encoding.decodeUtf8 message - -- The order of the type parameters is important, because it is more common to use type applications for providing the row type. fetch - :: forall b a. (ToParamList a, FromRow b) + :: forall a b. (ToParamList b, FromRow a) => Connection -> Text - -> a - -> IO (Either Error [b]) + -> b + -> IO (Either Error [a]) fetch conn query params = runExceptT $ do result <- execParams conn query params - columnTable <- ExceptT $ getColumnTable @b Proxy result + columnTable <- ExceptT $ getColumnTable @a Proxy result nRows <- liftIO $ LibPQ.ntuples result mapM (ExceptT . fromRow result columnTable) [0..nRows - 1] @@ -107,110 +74,21 @@ execute conn query params = runExceptT $ do execute_ :: Connection -> Text -> IO (Either Error ()) execute_ conn query = execute conn query () -newtype ColumnTable = ColumnTable (Vector (Column, Oid)) - deriving (Eq, Show) - -newColumnTable :: [(Column, Oid)] -> ColumnTable -newColumnTable = ColumnTable . Vector.fromList - -indexColumnTable :: ColumnTable -> Int -> (Column, Oid) -indexColumnTable (ColumnTable v) i = v `Vector.unsafeIndex` i - -toListColumnTable :: ColumnTable -> [(Column, Oid)] -toListColumnTable (ColumnTable v) = Vector.toList v - -class FromRow a where - getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable) - default getColumnTable :: (Generic a, GetColumnTable' (Rep a)) => Proxy a -> Result -> IO (Either Error ColumnTable) - getColumnTable Proxy = runExceptT . fmap newColumnTable . getColumnTable' @(Rep a) Proxy - - fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a) - default fromRow :: (Generic a, FromRow' (Rep a)) => Result -> ColumnTable -> Row -> IO (Either Error a) - fromRow result columnTable row = do - iRef <- newIORef 0 - runExceptT $ to <$> fromRow' (FromRowCtx result columnTable iRef) row - -class GetColumnTable' f where - getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [(Column, Oid)] - -instance GetColumnTable' f => GetColumnTable' (M1 D c f) where - getColumnTable' Proxy = getColumnTable' @f Proxy - -instance GetColumnTable' f => GetColumnTable' (M1 C c f) where - getColumnTable' Proxy = getColumnTable' @f Proxy - -instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where - getColumnTable' Proxy result = - (++) <$> getColumnTable' @f Proxy result <*> getColumnTable' @g Proxy result - -checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [(Column, Oid)] -checkColumn Proxy nameStr result = do - column <- ExceptT $ maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name - oid <- liftIO $ LibPQ.ftype result column - if validOid @a Proxy oid then - pure [(column, oid)] - else - except $ Left $ ErrorInvalidOid nameText oid - where - nameText = Text.pack nameStr - name = Encoding.encodeUtf8 nameText - -instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where - getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy - -instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where - getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy - --- | State kept for a call to 'fromRow'. -data FromRowCtx = FromRowCtx - Result -- ^ Obtained from 'LibPQ.execParams'. - ColumnTable -- ^ 'Vector' of expected columns indices and OIDs. - (IORef Int) -- ^ Index into 'ColumnTable', incremented after each column. TODO: Make this nicer. - -class FromRow' f where - fromRow' :: FromRowCtx -> Row -> ExceptT Error IO (f p) - -instance FromRow' f => FromRow' (M1 D c f) where - fromRow' ctx row = M1 <$> fromRow' ctx row - -instance FromRow' f => FromRow' (M1 C c f) where - fromRow' ctx row = M1 <$> fromRow' ctx row - -instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where - fromRow' ctx row = do - y <- fromRow' ctx row - z <- fromRow' ctx row - pure $ y :*: z - -decodeField - :: FromField t - => Text - -> (Row -> Maybe t -> Either Error t') - -> FromRowCtx - -> Row - -> ExceptT Error IO (M1 S m (Rec0 t') p) -decodeField nameText g (FromRowCtx result columnTable iRef) row = do - i <- liftIO $ readIORef iRef - liftIO $ modifyIORef' iRef (+1) - let (column, oid) = columnTable `indexColumnTable` i - mbField <- liftIO $ LibPQ.getvalue result row column - mbValue <- except $ fromFieldIfPresent oid mbField - value <- except $ g row mbValue - pure $ M1 $ K1 value - where - fromFieldIfPresent :: FromField u => LibPQ.Oid -> Maybe ByteString -> Either Error (Maybe u) - fromFieldIfPresent oid = maybe (Right Nothing) $ \field -> - first - (ErrorInvalidField (ErrorPosition row nameText) oid field) - (Just <$> fromField field) - -instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where - fromRow' = decodeField nameText $ \row -> - maybe (Left $ ErrorUnexpectedNull $ ErrorPosition row nameText) Right - where - nameText = Text.pack $ symbolVal @nameSym Proxy - -instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where - fromRow' = decodeField nameText $ const pure - where - nameText = Text.pack $ symbolVal @nameSym Proxy +execParams + :: ToParamList a + => Connection + -> Text + -> a + -> ExceptT Error IO Result +execParams conn query params = do + let queryBytes = Encoding.encodeUtf8 query + liftIO (LibPQ.execParams conn queryBytes (toParamList params) LibPQ.Binary) >>= \case + Nothing -> + except $ Left ErrorNoResult + Just result -> do + status <- liftIO $ LibPQ.resultStatus result + mbMessage <- liftIO $ LibPQ.resultErrorMessage result + case mbMessage of + Just "" -> pure result + Nothing -> pure result + Just message -> except $ Left $ ErrorInvalidResult status $ Encoding.decodeUtf8 message diff --git a/lib/Database/PostgreSQL/Opium/FromRow.hs b/lib/Database/PostgreSQL/Opium/FromRow.hs new file mode 100644 index 0000000..dc3c41b --- /dev/null +++ b/lib/Database/PostgreSQL/Opium/FromRow.hs @@ -0,0 +1,162 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE UndecidableInstances #-} + +module Database.PostgreSQL.Opium.FromRow + -- * FromRow + ( FromRow (..) + -- * Internal + , toListColumnTable + ) where + +import Control.Monad.IO.Class (liftIO) +import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT) +import Data.ByteString (ByteString) +import Data.Bifunctor (first) +import Data.Kind (Type) +import Data.Proxy (Proxy (..)) +import Data.Text (Text) +import Data.Vector (Vector) +import Database.PostgreSQL.LibPQ + ( Column + , Oid + , Result + , Row + ) +import GHC.Generics (Generic, C, D, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..)) +import GHC.TypeLits (KnownNat, KnownSymbol, Nat, natVal, symbolVal, type (+)) + +import qualified Data.Text as Text +import qualified Data.Text.Encoding as Encoding +import qualified Data.Vector as Vector +import qualified Database.PostgreSQL.LibPQ as LibPQ + +import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..)) +import Database.PostgreSQL.Opium.FromField (FromField (..), fromField) + +class FromRow a where + getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable) + default getColumnTable :: (Generic a, GetColumnTable' (Rep a)) => Proxy a -> Result -> IO (Either Error ColumnTable) + getColumnTable Proxy = runExceptT . fmap newColumnTable . getColumnTable' @(Rep a) Proxy + + fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a) + default fromRow :: (Generic a, FromRow' 0 (Rep a)) => Result -> ColumnTable -> Row -> IO (Either Error a) + fromRow result columnTable row = + runExceptT $ to <$> fromRow' @0 FRProxy (FromRowCtx result columnTable) row + +newtype ColumnTable = ColumnTable (Vector (Column, Oid)) + deriving (Eq, Show) + +newColumnTable :: [(Column, Oid)] -> ColumnTable +newColumnTable = ColumnTable . Vector.fromList + +indexColumnTable :: ColumnTable -> Int -> (Column, Oid) +indexColumnTable (ColumnTable v) i = v `Vector.unsafeIndex` i + +toListColumnTable :: ColumnTable -> [(Column, Oid)] +toListColumnTable (ColumnTable v) = Vector.toList v + +class GetColumnTable' f where + getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [(Column, Oid)] + +instance GetColumnTable' f => GetColumnTable' (M1 D c f) where + getColumnTable' Proxy = getColumnTable' @f Proxy + +instance GetColumnTable' f => GetColumnTable' (M1 C c f) where + getColumnTable' Proxy = getColumnTable' @f Proxy + +instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where + getColumnTable' Proxy result = + (++) <$> getColumnTable' @f Proxy result <*> getColumnTable' @g Proxy result + +instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where + getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy + +instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where + getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy + +-- | State kept for a call to 'fromRow'. +data FromRowCtx = FromRowCtx + Result -- ^ Obtained from 'LibPQ.execParams'. + ColumnTable -- ^ 'Vector' of expected columns indices and OIDs. + +data FRProxy (n :: Nat) (f :: Type -> Type) = FRProxy + +class FromRow' (n :: Nat) (f :: Type -> Type) where + type Members f :: Nat + + fromRow' :: FRProxy n f -> FromRowCtx -> Row -> ExceptT Error IO (f p) + +instance FromRow' n f => FromRow' n (M1 D c f) where + type Members (M1 D c f) = Members f + + fromRow' FRProxy ctx row = M1 <$> fromRow' @n FRProxy ctx row + +instance FromRow' n f => FromRow' n (M1 C c f) where + type Members (M1 C c f) = Members f + + fromRow' FRProxy ctx row = M1 <$> fromRow' @n FRProxy ctx row + +instance (FromRow' n f, FromRow' (n + Members f) g) => FromRow' n (f :*: g) where + type Members (f :*: g) = Members f + Members g + + fromRow' FRProxy ctx row = (:*:) <$> fromRow' @n FRProxy ctx row <*> fromRow' @(n + Members f) FRProxy ctx row + +instance {-# OVERLAPPABLE #-} (KnownNat n, KnownSymbol nameSym, FromField t) => FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where + type Members (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) = 1 + + fromRow' FRProxy = decodeField memberIndex nameText $ \row -> + maybe (Left $ ErrorUnexpectedNull $ ErrorPosition row nameText) Right + where + memberIndex = fromIntegral $ natVal @n Proxy + nameText = Text.pack $ symbolVal @nameSym Proxy + +instance {-# OVERLAPPING #-} (KnownNat n, KnownSymbol nameSym, FromField t) => FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where + type Members (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) = 1 + + fromRow' FRProxy = decodeField memberIndex nameText $ const pure + where + memberIndex = fromIntegral $ natVal @n Proxy + nameText = Text.pack $ symbolVal @nameSym Proxy + +checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [(Column, Oid)] +checkColumn Proxy nameStr result = do + column <- ExceptT $ maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name + oid <- liftIO $ LibPQ.ftype result column + if validOid @a Proxy oid then + pure [(column, oid)] + else + except $ Left $ ErrorInvalidOid nameText oid + where + nameText = Text.pack nameStr + name = Encoding.encodeUtf8 nameText + +decodeField + :: FromField t + => Int + -> Text + -> (Row -> Maybe t -> Either Error t') + -> FromRowCtx + -> Row + -> ExceptT Error IO (M1 S m (Rec0 t') p) +decodeField memberIndex nameText g (FromRowCtx result columnTable) row = do + let (column, oid) = columnTable `indexColumnTable` memberIndex + mbField <- liftIO $ LibPQ.getvalue result row column + mbValue <- except $ fromFieldIfPresent oid mbField + value <- except $ g row mbValue + pure $ M1 $ K1 value + where + fromFieldIfPresent :: FromField u => LibPQ.Oid -> Maybe ByteString -> Either Error (Maybe u) + fromFieldIfPresent oid = maybe (Right Nothing) $ \field -> + first + (ErrorInvalidField (ErrorPosition row nameText) oid field) + (Just <$> fromField field) + diff --git a/opium.cabal b/opium.cabal index 9ddd7aa..c553f65 100644 --- a/opium.cabal +++ b/opium.cabal @@ -62,13 +62,14 @@ library exposed-modules: Database.PostgreSQL.Opium, Database.PostgreSQL.Opium.FromField, + Database.PostgreSQL.Opium.FromRow, Database.PostgreSQL.Opium.ToField -- Modules included in this library but not exported. other-modules: Database.PostgreSQL.Opium.Error, - Database.PostgreSQL.Opium.ToParamList, - Database.PostgreSQL.Opium.Oid + Database.PostgreSQL.Opium.Oid, + Database.PostgreSQL.Opium.ToParamList -- LANGUAGE extensions used by modules in this package. -- other-extensions: diff --git a/test/Database/PostgreSQL/OpiumSpec.hs b/test/Database/PostgreSQL/OpiumSpec.hs index 50e0b34..0eb2a1e 100644 --- a/test/Database/PostgreSQL/OpiumSpec.hs +++ b/test/Database/PostgreSQL/OpiumSpec.hs @@ -17,6 +17,7 @@ import Test.Hspec (SpecWith, describe, it, shouldBe, shouldSatisfy) import qualified Database.PostgreSQL.LibPQ as LibPQ import qualified Database.PostgreSQL.Opium as Opium +import qualified Database.PostgreSQL.Opium.FromRow as Opium.FromRow data Person = Person { name :: Text @@ -64,7 +65,7 @@ shouldHaveColumns shouldHaveColumns proxy conn query expectedColumns = do Just result <- LibPQ.execParams conn query [] LibPQ.Binary columnTable <- Opium.getColumnTable proxy result - let actualColumns = fmap (map fst . Opium.toListColumnTable) columnTable + let actualColumns = fmap (map fst . Opium.FromRow.toListColumnTable) columnTable actualColumns `shouldBe` Right expectedColumns spec :: SpecWith Connection