164 lines
6.1 KiB
Haskell

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Database.PostgreSQL.Opium
( Error (..)
, FieldError (..)
, FromField (..)
, FromRow (..)
, fetch_
)
where
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Control.Monad.Trans.State (StateT (..), evalStateT, modify)
import Data.ByteString (ByteString)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Data.Text.Encoding.Error (lenientDecode)
import Database.PostgreSQL.LibPQ
  ( Column
  , Connection
  , Result
  , Row
  )
import GHC.Generics (C, D, Generic, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownSymbol, symbolVal)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Database.PostgreSQL.LibPQ as LibPQ
import Database.PostgreSQL.Opium.Error (Error (..))
import Database.PostgreSQL.Opium.FromField (FieldError (..), FromField (..), fromField)
-- | Execute a query without parameters, asking libpq for text-format results.
--
-- Fails with 'ErrorNoResult' when libpq returns no result at all.
--
-- Fails with 'ErrorInvalidResult' (carrying the status and libpq's error
-- message) when either of these holds:
--
-- * the result has a non-empty error message;
-- * its status is anything other than 'LibPQ.CommandOk' or 'LibPQ.TuplesOk'.
--
-- The error message is decoded from UTF-8 leniently: invalid bytes become
-- U+FFFD instead of making the 'Error' throw a 'UnicodeException' when it is
-- later inspected (e.g. under a non-UTF-8 client encoding).
execParams :: Connection -> ByteString -> IO (Either Error Result)
execParams conn query =
  LibPQ.execParams conn query [] LibPQ.Text >>= \case
    Nothing ->
      pure $ Left ErrorNoResult
    Just result -> do
      status <- LibPQ.resultStatus result
      message <-
        maybe "" (Encoding.decodeUtf8With lenientDecode)
          <$> LibPQ.resultErrorMessage result
      pure $
        if Text.null message && status `elem` successStatuses
          then Right result
          else Left $ ErrorInvalidResult status message
  where
    -- An empty error message alone does not mean the command ran to
    -- completion (e.g. 'LibPQ.EmptyQuery' has no message), so the status is
    -- checked too.
    successStatuses = [LibPQ.CommandOk, LibPQ.TuplesOk]
-- | Run a parameterless query and decode every row of its result as an @a@.
--
-- The column table for @a@ is resolved once, before any row is decoded.
-- Rows are then decoded in order, and the whole call stops at the first
-- 'Error' (from execution, column lookup, or any row).
fetch_ :: forall a. FromRow a => Connection -> ByteString -> IO (Either Error [a])
fetch_ conn query = runExceptT $ do
  result <- ExceptT (execParams conn query)
  -- TODO: Use unboxed array for columnTable
  columnTable <- ExceptT (getColumnTable (Proxy :: Proxy a) result)
  rowCount <- ExceptT (Right <$> LibPQ.ntuples result)
  let decodeRow = ExceptT . fromRow result columnTable
  traverse decodeRow [0 .. rowCount - 1]
-- | The result columns that back a row type's fields, looked up once per
-- query by 'getColumnTable' and then consulted by 'fromRow' for every row.
-- The generic default produces one entry per record field, in field order.
type ColumnTable = [Column]
-- | Types whose values can be decoded from the rows of a query 'Result'.
--
-- Both methods have 'Generic'-based defaults. For a record type whose fields
-- all have 'FromField' instances, the defaults can be used as-is.
class FromRow a where
  -- | Resolve, once per result, the column each field of @a@ is read from.
  --
  -- The generic default looks every field up by its record-field name. It
  -- also rejects columns whose type OID the field's 'FromField' instance does
  -- not accept.
  getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
  default getColumnTable
    :: (Generic a, GetColumnTable' (Rep a))
    => Proxy a -> Result -> IO (Either Error ColumnTable)
  getColumnTable Proxy = runExceptT . getColumnTable' @(Rep a) Proxy

  -- | Decode one row, reading fields from the columns listed in a table
  -- produced by 'getColumnTable'.
  fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
  default fromRow
    :: (Generic a, FromRow' (Rep a))
    => Result -> ColumnTable -> Row -> IO (Either Error a)
  -- The state is the column-table position of the field being decoded; the
  -- first field sits at position 0.
  fromRow result columnTable row =
    evalStateT (fmap to <$> fromRow' result columnTable row) 0
-- | Generic worker behind the default 'getColumnTable': walks a 'Rep' type
-- and yields one 'Column' per record field, in declaration order.
class GetColumnTable' f where
  getColumnTable'
    :: Proxy (f p)
    -> Result
    -> ExceptT Error IO [Column]

-- Datatype ('D') and constructor ('C') metadata wrappers contribute no
-- columns of their own; both delegate to the wrapped representation.
instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
  getColumnTable' Proxy result = getColumnTable' @f Proxy result

instance GetColumnTable' f => GetColumnTable' (M1 C c f) where
  getColumnTable' Proxy result = getColumnTable' @f Proxy result

-- A product of fields: the left half is resolved first (stopping on its
-- first error), and its columns precede those of the right half.
instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where
  getColumnTable' Proxy result = do
    leftColumns <- getColumnTable' @f Proxy result
    rightColumns <- getColumnTable' @g Proxy result
    pure (leftColumns ++ rightColumns)
-- | Find the result column called @nameStr@ (UTF-8 encoded for the lookup)
-- and check that its type OID is accepted by 'validOid' for @a@.
--
-- Fails with 'ErrorMissingColumn' when no column has that name, and with
-- 'ErrorInvalidOid' when the OID is rejected. On success the column is
-- returned as a one-element list, ready to be concatenated with the
-- columns of the other fields.
checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [Column]
checkColumn Proxy nameStr result = do
  column <- lookupColumn
  oid <- ExceptT (Right <$> LibPQ.ftype result column)
  except $
    if validOid (Proxy :: Proxy a) oid
      then Right [column]
      else Left (ErrorInvalidOid nameText oid)
  where
    nameText = Text.pack nameStr
    lookupColumn =
      ExceptT $
        maybe (Left (ErrorMissingColumn nameText)) Right
          <$> LibPQ.fnumber result (Encoding.encodeUtf8 nameText)
-- | A named record field of a non-'Maybe' type: its column is looked up by
-- the field's name and validated against the field type @t@.
instance {-# OVERLAPPABLE #-}
    (FromField t, KnownSymbol nameSym)
    => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  getColumnTable' Proxy result =
    checkColumn (Proxy @t) (symbolVal (Proxy @nameSym)) result

-- | A named record field of type @Maybe t@: its column is looked up by the
-- field's name and validated against the wrapped type @t@.
instance {-# OVERLAPPING #-}
    (KnownSymbol nameSym, FromField t)
    => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  getColumnTable' Proxy result =
    checkColumn (Proxy @t) (symbolVal (Proxy @nameSym)) result
-- | Generic worker behind the default 'fromRow'. The 'Int' state is the
-- column-table position of the field currently being decoded.
class FromRow' f where
  fromRow'
    :: Result
    -> ColumnTable
    -> Row
    -> StateT Int IO (Either Error (f p))

-- Metadata wrappers: decode the wrapped representation, then re-wrap it.
instance FromRow' f => FromRow' (M1 D c f) where
  fromRow' result columnTable row = do
    decoded <- fromRow' result columnTable row
    pure (M1 <$> decoded)

instance FromRow' f => FromRow' (M1 C c f) where
  fromRow' result columnTable row = do
    decoded <- fromRow' result columnTable row
    pure (M1 <$> decoded)

-- A product of fields. Field decoders leave the state where they found it,
-- so the position is advanced exactly once here: between the left half's
-- last field and the right half's first. Both halves are always decoded;
-- when both fail, the left half's error is the one reported.
instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where
  fromRow' result columnTable row = do
    lhs <- fromRow' result columnTable row
    modify (+ 1)
    rhs <- fromRow' result columnTable row
    pure ((:*:) <$> lhs <*> rhs)
-- | Decode, for one row, the field at the current column-table position (the
-- 'StateT' state).
--
-- The decoding steps are:
--
-- 1. Fetch the raw value with 'LibPQ.getvalue' and decode it from UTF-8.
-- 2. Unless it is SQL @NULL@, parse it with 'fromField' using the column's
--    type OID. A parse failure is reported as 'ErrorDecode'.
-- 3. Pass the possibly-missing value to @g@, which produces the field value.
--
-- The state is returned unchanged; advancing between fields is left to the
-- caller.
decodeField
  :: FromField t
  => Text
  -- ^ field name, used in the errors produced here
  -> (Row -> Maybe t -> Either Error t')
  -- ^ turns the decoded value ('Nothing' for SQL @NULL@) into the field value
  -> Result
  -> ColumnTable
  -> Row
  -> StateT Int IO (Either Error (M1 S m (Rec0 t') p))
decodeField nameText g result columnTable row = StateT $ \i -> do
  v <- runExceptT $ do
    column <- except $ columnAt i
    oid <- ExceptT $ Right <$> LibPQ.ftype result column
    mbField <- getFieldText column
    mbValue <- getValue oid mbField
    value <- except $ g row mbValue
    pure $ M1 $ K1 value
  pure (v, i)
  where
    -- Total replacement for @columnTable !! ix@. A column table shorter than
    -- the record (e.g. from a hand-written 'getColumnTable' paired with the
    -- generic 'fromRow') yields 'ErrorMissingColumn' instead of crashing.
    columnAt :: Int -> Either Error Column
    columnAt ix
      | ix < 0 = Left $ ErrorMissingColumn nameText
      | otherwise = case drop ix columnTable of
          column : _ -> Right column
          [] -> Left $ ErrorMissingColumn nameText
    -- 'LibPQ.getvalue' yields 'Nothing' for SQL NULL.
    -- NOTE(review): 'Encoding.decodeUtf8' throws on bytes that are not valid
    -- UTF-8; this assumes a UTF-8 client encoding — confirm.
    getFieldText :: Column -> ExceptT Error IO (Maybe Text)
    getFieldText column =
      ExceptT $ Right . fmap Encoding.decodeUtf8 <$> LibPQ.getvalue result row column
    getValue :: FromField u => LibPQ.Oid -> Maybe Text -> ExceptT Error IO (Maybe u)
    getValue oid = except . maybe
      (Right Nothing)
      (fmap Just . mapLeft (ErrorDecode row nameText) . fromField oid)
-- | A named record field of a non-'Maybe' type: SQL @NULL@ is rejected with
-- 'ErrorUnexpectedNull'.
instance {-# OVERLAPPABLE #-}
    (FromField t, KnownSymbol nameSym)
    => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  fromRow' = decodeField fieldName requireValue
    where
      fieldName = Text.pack (symbolVal (Proxy @nameSym))
      requireValue row Nothing = Left (ErrorUnexpectedNull row fieldName)
      requireValue _ (Just value) = Right value

-- | A named record field of type @Maybe t@: SQL @NULL@ becomes 'Nothing'.
instance {-# OVERLAPPING #-}
    (KnownSymbol nameSym, FromField t)
    => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  fromRow' = decodeField fieldName (\_ mbValue -> Right mbValue)
    where
      fieldName = Text.pack (symbolVal (Proxy @nameSym))
-- | Apply a function to the 'Left' side of an 'Either', passing a 'Right'
-- through untouched (equivalent to 'Data.Bifunctor.first').
mapLeft :: (b -> c) -> Either b a -> Either c a
mapLeft f = either (Left . f) Right