{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Database.PostgreSQL.Opium.FromRow
  -- * FromRow
  ( FromRow (..)
  -- * Internal
  , ColumnTable
  , toListColumnTable
  ) where

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.ByteString (ByteString)
import Data.Bifunctor (first)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Data.Vector (Vector)
import Database.PostgreSQL.LibPQ
  ( Column
  , Oid
  , Result
  , Row
  )
import GHC.Generics (Generic, C, D, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownNat, KnownSymbol, Nat, natVal, symbolVal, type (+))

import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Data.Vector as Vector
import qualified Database.PostgreSQL.LibPQ as LibPQ

import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..))
import Database.PostgreSQL.Opium.FromField (FromField (..), fromField)

-- | Types that can be decoded from one row of a query 'Result'.
--
-- Both methods have default implementations that work on the 'Generic'
-- representation of a single-constructor record type.
class FromRow a where
  -- | Build the 'ColumnTable' for @a@ from a 'Result'.
  --
  -- The default implementation looks up one result column per record field
  -- (by field name) and fails with 'ErrorMissingColumn' or 'ErrorInvalidOid'
  -- if a column is absent or its OID is rejected by 'validOid'.
  getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
  default getColumnTable :: (Generic a, GetColumnTable' (Rep a)) => Proxy a -> Result -> IO (Either Error ColumnTable)
  getColumnTable Proxy = runExceptT . fmap newColumnTable . getColumnTable' @(Rep a) Proxy

  -- | Decode the given 'Row' of the 'Result' into an @a@, using a
  -- 'ColumnTable' to find the column and OID of every field.
  --
  -- The default implementation decodes the fields positionally: the n-th
  -- record field uses the n-th entry of the 'ColumnTable', counting from 0.
  fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
  default fromRow :: (Generic a, FromRow' 0 (Rep a)) => Result -> ColumnTable -> Row -> IO (Either Error a)
  fromRow result columnTable row = runExceptT $ to <$> fromRow' @0 FRProxy (FromRowCtx result columnTable) row

-- | One @(column, OID)@ pair per record field, in field order.
newtype ColumnTable = ColumnTable (Vector (Column, Oid))
  deriving (Eq, Show)

-- | Wrap a list of @(column, OID)@ pairs into a 'ColumnTable'.
newColumnTable :: [(Column, Oid)] -> ColumnTable
newColumnTable = ColumnTable . Vector.fromList

-- | Bounds-checked lookup of the entry at the given position.
--
-- Returns 'Nothing' when the index lies outside the table. The table handed
-- to 'fromRow' is supplied by the caller, so it is not guaranteed to have
-- been built for the type that is being decoded; an unchecked index would
-- read out of bounds in that case.
indexColumnTable :: ColumnTable -> Int -> Maybe (Column, Oid)
indexColumnTable (ColumnTable v) i = v Vector.!? i

-- | All entries of a 'ColumnTable', in order.
toListColumnTable :: ColumnTable -> [(Column, Oid)]
toListColumnTable (ColumnTable v) = Vector.toList v

-- | Generic worker behind the default 'getColumnTable': collects the
-- @(column, OID)@ pairs for a generic representation @f@.
class GetColumnTable' f where
  getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [(Column, Oid)]

-- | Datatype metadata contributes no columns; delegate to the wrapped representation.
instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
  getColumnTable' Proxy = getColumnTable' @f Proxy

-- | Constructor metadata contributes no columns; delegate to the wrapped representation.
instance GetColumnTable' f => GetColumnTable' (M1 C c f) where
  getColumnTable' Proxy = getColumnTable' @f Proxy

-- | A product yields the columns of its left half followed by those of its right half.
instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where
  getColumnTable' Proxy result =
    (++) <$> getColumnTable' @f Proxy result <*> getColumnTable' @g Proxy result

-- | A named record field of type @t@: check the column named like the field against @t@.
instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy

-- | A named record field of type @Maybe t@: the OID is checked against the inner type @t@.
instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy

-- | Number of members in the generic representation of a record type (doesn't support sum types).
type family NumberOfMembers f where
  -- The data type itself has as many members as the type that it defines.
  NumberOfMembers (M1 D _ f) = NumberOfMembers f
  -- The constructor has as many members as the type that it contains.
  NumberOfMembers (M1 C _ f) = NumberOfMembers f
  -- A product type has as many members as its subtypes have together.
  NumberOfMembers (f :*: g) = NumberOfMembers f + NumberOfMembers g
  -- A selector has/is exactly one member.
  NumberOfMembers (M1 S _ f) = 1

-- | State kept for a call to 'fromRow'.
data FromRowCtx = FromRowCtx
  Result -- ^ Obtained from 'LibPQ.execParams'.
  ColumnTable -- ^ 'Vector' of expected columns indices and OIDs.

-- Specialized proxy type to be used instead of `Proxy (n, f)`
data FRProxy (n :: Nat) (f :: Type -> Type) = FRProxy

-- | Generic worker behind the default 'fromRow'.
--
-- @n@ is the position in the 'ColumnTable' of the first member contained in
-- the representation @f@.
class FromRow' (n :: Nat) (f :: Type -> Type) where
  fromRow' :: FRProxy n f -> FromRowCtx -> Row -> ExceptT Error IO (f p)

-- | Datatype metadata: decode the wrapped representation at the same offset.
instance FromRow' n f => FromRow' n (M1 D c f) where
  fromRow' FRProxy ctx row =
    M1 <$> fromRow' @n FRProxy ctx row

-- | Constructor metadata: decode the wrapped representation at the same offset.
instance FromRow' n f => FromRow' n (M1 C c f) where
  fromRow' FRProxy ctx row =
    M1 <$> fromRow' @n FRProxy ctx row

-- | A product: the left half starts at offset @n@, the right half directly
-- after all members of the left half.
instance (FromRow' n f, FromRow' (n + NumberOfMembers f) g) => FromRow' n (f :*: g) where
  fromRow' FRProxy ctx row =
    (:*:) <$> fromRow' @n FRProxy ctx row <*> fromRow' @(n + NumberOfMembers f) FRProxy ctx row

-- | A named record field of type @t@: a missing value is reported as 'ErrorUnexpectedNull'.
instance {-# OVERLAPPABLE #-} (KnownNat n, KnownSymbol nameSym, FromField t) => FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  fromRow' FRProxy = decodeField memberIndex nameText $ \row ->
    maybe (Left $ ErrorUnexpectedNull $ ErrorPosition row nameText) Right
    where
      memberIndex = fromIntegral $ natVal @n Proxy
      nameText = Text.pack $ symbolVal @nameSym Proxy

-- | A named record field of type @Maybe t@: a missing value becomes 'Nothing'.
instance {-# OVERLAPPING #-} (KnownNat n, KnownSymbol nameSym, FromField t) => FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  fromRow' FRProxy = decodeField memberIndex nameText $ const pure
    where
      memberIndex = fromIntegral $ natVal @n Proxy
      nameText = Text.pack $ symbolVal @nameSym Proxy

-- | Look up the column with the given name in the 'Result' and check its OID
-- against the field type @a@.
--
-- Yields a singleton list with the column and its OID on success. Fails with
-- 'ErrorMissingColumn' when 'LibPQ.fnumber' finds no column of that name, and
-- with 'ErrorInvalidOid' when 'validOid' rejects the column's OID.
checkColumn
  :: forall a. FromField a
  => Proxy a
  -> String
  -> Result
  -> ExceptT Error IO [(Column, Oid)]
checkColumn Proxy nameStr result = do
  column <- ExceptT $
    maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name
  oid <- liftIO $ LibPQ.ftype result column
  if validOid @a Proxy oid then
    pure [(column, oid)]
  else
    except $ Left $ ErrorInvalidOid nameText oid
  where
    nameText = Text.pack nameStr
    -- The column name is passed to libpq as UTF-8 encoded bytes.
    name = Encoding.encodeUtf8 nameText

-- | Decode a single record field.
--
-- Arguments, in order: the position of the field in the 'ColumnTable', the
-- field name (used for error reporting), a callback that turns the optional
-- decoded value into the final field value, the decoding context and the row.
--
-- Fails with 'ErrorMissingColumn' when the 'ColumnTable' has no entry at the
-- given position, with 'ErrorInvalidField' when 'fromField' rejects the raw
-- value, and with whatever error the callback returns.
decodeField
  :: FromField t
  => Int
  -> Text
  -> (Row -> Maybe t -> Either Error t')
  -> FromRowCtx
  -> Row
  -> ExceptT Error IO (M1 S m (Rec0 t') p)
decodeField memberIndex nameText g (FromRowCtx result columnTable) row = do
  -- Checked lookup: a table that is too short for this record is reported as
  -- an error instead of being indexed out of bounds.
  (column, oid) <- except $
    maybe (Left $ ErrorMissingColumn nameText) Right $
      columnTable `indexColumnTable` memberIndex
  mbField <- liftIO $ LibPQ.getvalue result row column
  mbValue <- except $ fromFieldIfPresent oid mbField
  value <- except $ g row mbValue
  pure $ M1 $ K1 value
  where
    -- Run 'fromField' only when a raw value is present; an absent value is
    -- passed on as 'Nothing' so that the callback decides how to treat it.
    fromFieldIfPresent :: FromField u => LibPQ.Oid -> Maybe ByteString -> Either Error (Maybe u)
    fromFieldIfPresent oid =
      maybe (Right Nothing) $ \field ->
        first
          (ErrorInvalidField (ErrorPosition row nameText) oid field)
          (Just <$> fromField field)