{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}

module Database.PostgreSQL.Opium.FromRow
  ( -- * FromRow
    -- (Section headings must sit inside the export list for Haddock to
    -- pick them up.)
    FromRow (..)
    -- * Internal
  , toListColumnTable
  ) where

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.ByteString (ByteString)
import Data.Bifunctor (first)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Data.Vector (Vector)
import Database.PostgreSQL.LibPQ
  ( Column
  , Oid
  , Result
  , Row
  )
import GHC.Generics (Generic, C, D, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownNat, KnownSymbol, Nat, natVal, symbolVal, type (+))

import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Data.Vector as Vector
import qualified Database.PostgreSQL.LibPQ as LibPQ

import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..))
import Database.PostgreSQL.Opium.FromField (FromField (..), fromField)

-- | Types that can be decoded from a row of a query 'Result'.
--
-- For a record type with a 'Generic' instance both methods have default
-- implementations: each record field is looked up by its name as a result
-- column, and a field of type @'Maybe' t@ accepts SQL @NULL@.
class FromRow a where
  -- | Locate every column required by @a@ in the result and check that its
  -- type OID is acceptable. The resulting 'ColumnTable' is intended to be
  -- passed to 'fromRow' for the same type @a@.
  getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
  default getColumnTable
    :: (Generic a, GetColumnTable' (Rep a))
    => Proxy a -> Result -> IO (Either Error ColumnTable)
  getColumnTable Proxy = runExceptT . fmap newColumnTable . getColumnTable' @(Rep a) Proxy

  -- | Decode a single row using a column table obtained from
  -- 'getColumnTable'.
  fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
  default fromRow
    :: (Generic a, FromRow' 0 (Rep a))
    => Result -> ColumnTable -> Row -> IO (Either Error a)
  fromRow result columnTable row =
    runExceptT $ to <$> fromRow' @0 FRProxy (FromRowCtx result columnTable) row

-- | Result column index and type OID for every field of a row type, in
-- field order. Built by 'getColumnTable', consumed by 'fromRow'.
newtype ColumnTable = ColumnTable (Vector (Column, Oid))
  deriving (Eq, Show)

newColumnTable :: [(Column, Oid)] -> ColumnTable
newColumnTable = ColumnTable . Vector.fromList

-- | Look up the entry for the field at position @i@.
--
-- The index is bounds-checked. A 'ColumnTable' built by 'getColumnTable'
-- for one type can be handed to 'fromRow' for a different type, and an
-- unchecked index would then read outside the vector, which is undefined
-- behaviour. Such misuse now raises a descriptive error instead.
indexColumnTable :: ColumnTable -> Int -> (Column, Oid)
indexColumnTable (ColumnTable v) i =
  case v Vector.!? i of
    Just entry -> entry
    Nothing ->
      error $
        "Database.PostgreSQL.Opium.FromRow.indexColumnTable: index "
          <> show i
          <> " is out of range for a column table with "
          <> show (Vector.length v)
          <> " entries (was the ColumnTable obtained from getColumnTable for a different type?)"

toListColumnTable :: ColumnTable -> [(Column, Oid)]
toListColumnTable (ColumnTable v) = Vector.toList v

-- | Generic traversal behind the default 'getColumnTable'. It yields one
-- @(column, oid)@ entry per record field, left to right.
class GetColumnTable' f where
  getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [(Column, Oid)]

instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
  getColumnTable' Proxy = getColumnTable' @f Proxy

instance GetColumnTable' f => GetColumnTable' (M1 C c f) where
  getColumnTable' Proxy = getColumnTable' @f Proxy

instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where
  getColumnTable' Proxy result =
    (++) <$> getColumnTable' @f Proxy result <*> getColumnTable' @g Proxy result

instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) =>
    GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy

-- | A @'Maybe' t@ field is backed by a column whose OID must be valid for @t@.
instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) =>
    GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy

-- | State kept for a call to 'fromRow'.
data FromRowCtx = FromRowCtx
  Result      -- ^ Obtained from 'LibPQ.execParams'.
  ColumnTable -- ^ 'Vector' of expected column indices and OIDs.
-- | Proxy carrying @n@, the position in the 'ColumnTable' of the first field
-- covered by the generic representation @f@.
data FRProxy (n :: Nat) (f :: Type -> Type) = FRProxy

-- | Generic decoder behind the default 'fromRow'.
--
-- 'Members' counts the fields inside @f@. For a product the right operand
-- therefore starts at column-table position @n + Members f@.
class FromRow' (n :: Nat) (f :: Type -> Type) where
  type Members f :: Nat
  fromRow' :: FRProxy n f -> FromRowCtx -> Row -> ExceptT Error IO (f p)

instance FromRow' n f => FromRow' n (M1 D c f) where
  type Members (M1 D c f) = Members f
  fromRow' FRProxy ctx row = fmap M1 (fromRow' @n FRProxy ctx row)

instance FromRow' n f => FromRow' n (M1 C c f) where
  type Members (M1 C c f) = Members f
  fromRow' FRProxy ctx row = fmap M1 (fromRow' @n FRProxy ctx row)

instance (FromRow' n f, FromRow' (n + Members f) g) => FromRow' n (f :*: g) where
  type Members (f :*: g) = Members f + Members g
  fromRow' FRProxy ctx row = do
    lhs <- fromRow' @n FRProxy ctx row
    rhs <- fromRow' @(n + Members f) FRProxy ctx row
    pure (lhs :*: rhs)

-- | A required (non-'Maybe') field: SQL @NULL@ is reported as
-- 'ErrorUnexpectedNull'.
instance {-# OVERLAPPABLE #-} (KnownNat n, KnownSymbol nameSym, FromField t) =>
    FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  type Members (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) = 1
  fromRow' FRProxy = decodeField columnIndex fieldName rejectNull
    where
      columnIndex = fromIntegral (natVal (Proxy :: Proxy n))
      fieldName = Text.pack (symbolVal (Proxy :: Proxy nameSym))
      rejectNull row = maybe (Left (ErrorUnexpectedNull (ErrorPosition row fieldName))) Right

-- | An optional ('Maybe') field: SQL @NULL@ becomes 'Nothing'.
instance {-# OVERLAPPING #-} (KnownNat n, KnownSymbol nameSym, FromField t) =>
    FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  type Members (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) = 1
  fromRow' FRProxy = decodeField columnIndex fieldName (\_ -> Right)
    where
      columnIndex = fromIntegral (natVal (Proxy :: Proxy n))
      fieldName = Text.pack (symbolVal (Proxy :: Proxy nameSym))

-- | Find the result column named @nameStr@ and check its type OID with
-- @a@'s 'validOid'.
--
-- It fails with 'ErrorMissingColumn' when 'LibPQ.fnumber' finds no such
-- column, and with 'ErrorInvalidOid' when the OID is rejected. On success it
-- yields a one-element list.
checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [(Column, Oid)]
checkColumn Proxy nameStr result = do
  column <- ExceptT (maybe (Left (ErrorMissingColumn label)) Right <$> LibPQ.fnumber result encodedName)
  oid <- liftIO (LibPQ.ftype result column)
  except $
    if validOid @a Proxy oid
      then Right [(column, oid)]
      else Left (ErrorInvalidOid label oid)
  where
    label = Text.pack nameStr
    encodedName = Encoding.encodeUtf8 label

-- | Decode the field at position @memberIndex@ of the column table from
-- @row@, wrapped in the generic selector representation.
--
-- A present value is parsed with 'fromField', and a parse failure becomes
-- 'ErrorInvalidField'. The resulting @'Maybe' t@ is then passed to @g@,
-- which decides how an absent value is handled.
decodeField
  :: FromField t
  => Int
  -> Text
  -> (Row -> Maybe t -> Either Error t')
  -> FromRowCtx
  -> Row
  -> ExceptT Error IO (M1 S m (Rec0 t') p)
decodeField memberIndex nameText g (FromRowCtx result columnTable) row = do
  let (column, oid) = columnTable `indexColumnTable` memberIndex
      position = ErrorPosition row nameText
      parse field = first (ErrorInvalidField position oid field) (fromField field)
  rawField <- liftIO (LibPQ.getvalue result row column)
  parsed <- except (traverse parse rawField)
  M1 . K1 <$> except (g row parsed)