163 lines
6.2 KiB
Haskell

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
module Database.PostgreSQL.Opium.FromRow
-- * FromRow
( FromRow (..)
-- * Internal
, toListColumnTable
) where
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.ByteString (ByteString)
import Data.Bifunctor (first)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Data.Vector (Vector)
import Database.PostgreSQL.LibPQ
( Column
, Oid
, Result
, Row
)
import GHC.Generics (Generic, C, D, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownNat, KnownSymbol, Nat, natVal, symbolVal, type (+))
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Data.Vector as Vector
import qualified Database.PostgreSQL.LibPQ as LibPQ
import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..))
import Database.PostgreSQL.Opium.FromField (FromField (..), fromField)
-- | Types that can be decoded from one row of a query 'Result'.
--
-- Both methods have 'Generic'-based defaults: for a single-constructor record
-- type with a 'Generic' instance, an empty @instance FromRow T@ suffices. Each
-- record field is matched to the result column whose name equals the field's
-- selector name.
class FromRow a where
  -- | Resolve the result column of every field of @a@ and check that each
  -- column's type OID is accepted by the field's 'FromField' instance.
  getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
  default getColumnTable
    :: (Generic a, GetColumnTable' (Rep a))
    => Proxy a
    -> Result
    -> IO (Either Error ColumnTable)
  getColumnTable Proxy result = runExceptT $ do
    columns <- getColumnTable' @(Rep a) Proxy result
    pure (newColumnTable columns)

  -- | Decode the given row. The 'ColumnTable' is indexed by field position,
  -- so it should come from 'getColumnTable' for the same type and result.
  fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
  default fromRow
    :: (Generic a, FromRow' 0 (Rep a))
    => Result
    -> ColumnTable
    -> Row
    -> IO (Either Error a)
  fromRow result columnTable row =
    runExceptT (to <$> fromRow' @0 FRProxy context row)
    where
      context = FromRowCtx result columnTable
-- | The result column and its type OID for every record field, stored in
-- field order so that a field's position is the index of its entry.
newtype ColumnTable
  = ColumnTable (Vector (Column, Oid))
  deriving (Show, Eq)
-- | Build a 'ColumnTable' from entries listed in field order.
newColumnTable :: [(Column, Oid)] -> ColumnTable
newColumnTable entries = ColumnTable (Vector.fromList entries)
-- | Look up the column and type OID of the field at the given 0-based
-- position.
--
-- The lookup is bounds-checked on purpose. 'fromRow' accepts any
-- 'ColumnTable', including one built by 'getColumnTable' for a different
-- (smaller) row type. 'Vector.unsafeIndex' would then read past the end of
-- the vector; 'Vector.!' fails with an index-out-of-bounds error instead.
indexColumnTable :: ColumnTable -> Int -> (Column, Oid)
indexColumnTable (ColumnTable columns) i = columns Vector.! i
-- | All entries of a 'ColumnTable', in field order.
toListColumnTable :: ColumnTable -> [(Column, Oid)]
toListColumnTable table = case table of
  ColumnTable entries -> Vector.toList entries
-- | Generic worker behind the default 'getColumnTable': walks a 'Rep' and
-- yields one @(column, OID)@ entry per record field, left to right.
class GetColumnTable' f where
  getColumnTable'
    :: Proxy (f p)
    -> Result
    -> ExceptT Error IO [(Column, Oid)]
-- | Datatype metadata contributes no columns of its own; look inside it.
instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
  getColumnTable' Proxy result = getColumnTable' @f Proxy result
-- | Constructor metadata contributes no columns of its own; look inside it.
instance GetColumnTable' f => GetColumnTable' (M1 C c f) where
  getColumnTable' Proxy result = getColumnTable' @f Proxy result
-- | A product of fields: the left operand's entries precede the right's,
-- which keeps the table in field order.
instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where
  getColumnTable' Proxy result = do
    leftColumns <- getColumnTable' @f Proxy result
    rightColumns <- getColumnTable' @g Proxy result
    pure (leftColumns ++ rightColumns)
-- | A named record field (the 'Maybe' case is handled by a more specific
-- instance): its column must exist and its OID must be accepted by @t@.
instance {-# OVERLAPPABLE #-}
  (FromField t, KnownSymbol nameSym)
  => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  getColumnTable' Proxy = checkColumn (Proxy @t) fieldName
    where
      fieldName = symbolVal (Proxy @nameSym)
-- | A named record field of type @'Maybe' t@: its column is validated against
-- the OID rules of the wrapped type @t@.
instance {-# OVERLAPPING #-}
  (KnownSymbol nameSym, FromField t)
  => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  getColumnTable' Proxy = checkColumn (Proxy @t) fieldName
    where
      fieldName = symbolVal (Proxy @nameSym)
-- | Read-only context shared by every step of a single 'fromRow' call.
data FromRowCtx = FromRowCtx
  Result      -- ^ The query result being decoded (obtained from 'LibPQ.execParams').
  ColumnTable -- ^ Column index and type OID for every record field, in field order.
-- | Proxy fixing both the starting field position @n@ and the generic
-- representation @f@ for a call to 'fromRow''.
data FRProxy (n :: Nat) (f :: Type -> Type)
  = FRProxy
-- | Generic worker behind the default 'fromRow'.
--
-- @n@ is the 'ColumnTable' position of the first field covered by @f@.
-- 'Members' counts the fields covered by @f@, which lets the right operand of
-- ':*:' start at @n + Members f@.
class FromRow' (n :: Nat) (f :: Type -> Type) where
  -- | Number of record fields contained in @f@.
  type Members f :: Nat
  fromRow'
    :: FRProxy n f
    -> FromRowCtx
    -> Row
    -> ExceptT Error IO (f p)
-- | Datatype metadata: decode the contents at the same starting position and
-- rewrap them.
instance FromRow' n f => FromRow' n (M1 D c f) where
  type Members (M1 D c f) = Members f
  fromRow' FRProxy ctx = fmap M1 . fromRow' @n FRProxy ctx
-- | Constructor metadata: decode the contents at the same starting position
-- and rewrap them.
instance FromRow' n f => FromRow' n (M1 C c f) where
  type Members (M1 C c f) = Members f
  fromRow' FRProxy ctx = fmap M1 . fromRow' @n FRProxy ctx
-- | A product of fields. The left operand starts at position @n@; the right
-- operand starts just after the @'Members' f@ fields covered by the left one.
instance (FromRow' n f, FromRow' (n + Members f) g) => FromRow' n (f :*: g) where
  type Members (f :*: g) = Members f + Members g
  fromRow' FRProxy ctx row = do
    leftValue <- fromRow' @n FRProxy ctx row
    rightValue <- fromRow' @(n + Members f) FRProxy ctx row
    pure (leftValue :*: rightValue)
-- | A named record field (the 'Maybe' case is handled by a more specific
-- instance): a SQL @NULL@ is rejected with 'ErrorUnexpectedNull'.
instance {-# OVERLAPPABLE #-}
  (KnownNat n, KnownSymbol nameSym, FromField t)
  => FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  type Members (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) = 1
  fromRow' FRProxy = decodeField position fieldName rejectNull
    where
      position = fromIntegral (natVal (Proxy @n))
      fieldName = Text.pack (symbolVal (Proxy @nameSym))
      rejectNull row Nothing = Left (ErrorUnexpectedNull (ErrorPosition row fieldName))
      rejectNull _ (Just value) = Right value
-- | A named record field of type @'Maybe' t@: a SQL @NULL@ becomes 'Nothing',
-- any other value is parsed as @t@ and wrapped in 'Just'.
instance {-# OVERLAPPING #-}
  (KnownNat n, KnownSymbol nameSym, FromField t)
  => FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  type Members (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) = 1
  fromRow' FRProxy = decodeField position fieldName (\_ mbValue -> Right mbValue)
    where
      position = fromIntegral (natVal (Proxy @n))
      fieldName = Text.pack (symbolVal (Proxy @nameSym))
-- | Resolve the result column named @nameStr@ and check that its type OID is
-- accepted by @a@'s 'FromField' instance ('validOid').
--
-- Fails with 'ErrorMissingColumn' when the result has no column of that name,
-- and with 'ErrorInvalidOid' when the column's OID is rejected. On success the
-- single @(column, OID)@ entry is returned as a one-element list.
checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [(Column, Oid)]
checkColumn Proxy nameStr result = do
  mbColumn <- liftIO $ LibPQ.fnumber result columnName
  column <- except $ maybe (Left $ ErrorMissingColumn fieldName) Right mbColumn
  oid <- liftIO $ LibPQ.ftype result column
  if not (validOid @a Proxy oid)
    then except $ Left $ ErrorInvalidOid fieldName oid
    else pure [(column, oid)]
  where
    fieldName = Text.pack nameStr
    -- 'LibPQ.fnumber' takes the column name as bytes; encode it as UTF-8.
    columnName = Encoding.encodeUtf8 fieldName
-- | Read and decode the value of one record field from a row.
--
-- The field's column and OID come from the 'ColumnTable' entry at
-- @memberIndex@. A missing value (SQL @NULL@) reaches @g@ as 'Nothing'. A
-- present value is parsed with 'fromField'; a parse failure is reported as
-- 'ErrorInvalidField' at the field's position. @g@ then turns the optional
-- value into the final field value (for example, rejecting 'Nothing').
decodeField
  :: FromField t
  => Int
  -- ^ Position of the field in the 'ColumnTable'.
  -> Text
  -- ^ Field name, used in error positions.
  -> (Row -> Maybe t -> Either Error t')
  -- ^ Post-processing of the possibly-missing decoded value.
  -> FromRowCtx
  -> Row
  -> ExceptT Error IO (M1 S m (Rec0 t') p)
decodeField memberIndex nameText g (FromRowCtx result columnTable) row = do
  mbField <- liftIO $ LibPQ.getvalue result row column
  mbValue <- except $ traverse parseField mbField
  M1 . K1 <$> except (g row mbValue)
  where
    (column, oid) = indexColumnTable columnTable memberIndex
    parseField field =
      first (ErrorInvalidField (ErrorPosition row nameText) oid field) (fromField field)