-- 192 lines
-- 6.8 KiB
-- Haskell

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Database.PostgreSQL.Opium
( ColumnTable
, Error (..)
, ErrorPosition (..)
, FromField (..)
, FromRow (..)
, RawField (..)
, fetch
, fetch_
, toListColumnTable
)
where
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.ByteString (ByteString)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Data.Vector (Vector)
import Database.PostgreSQL.LibPQ
( Column
, Connection
, Oid
, Result
, Row
)
import GHC.Generics (C, D, Generic, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownSymbol, symbolVal)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Data.Vector as Vector
import qualified Database.PostgreSQL.LibPQ as LibPQ
import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..))
import Database.PostgreSQL.Opium.FromField (FromField (..), fromField, RawField (..))
import Database.PostgreSQL.Opium.ToParamList (ToParamList (..))
-- | Run a parameterised query and hand back the raw libpq 'Result'.
--
-- The query text is sent UTF-8 encoded and results are requested in
-- 'LibPQ.Binary' format.
--
-- Fails with 'ErrorNoResult' when libpq returns no result at all, and with
-- 'ErrorInvalidResult' (carrying the result status and the message text)
-- when the result has a non-empty error message attached.
execParams
  :: ToParamList a
  => Connection
  -> Text
  -> a
  -> ExceptT Error IO Result
execParams conn query params = do
  let queryBytes = Encoding.encodeUtf8 query
  liftIO (LibPQ.execParams conn queryBytes (toParamList params) LibPQ.Binary) >>= \case
    Nothing ->
      except $ Left ErrorNoResult
    Just result -> do
      status <- liftIO $ LibPQ.resultStatus result
      mbMessage <- liftIO $ LibPQ.resultErrorMessage result
      case mbMessage of
        -- An absent or empty message both mean "no error".
        Just "" -> pure result
        Nothing -> pure result
        Just message ->
          except $ Left $ ErrorInvalidResult status $ decodeMessage message
  where
    -- 'Encoding.decodeUtf8' throws an impure exception on malformed input,
    -- which would bypass the 'ExceptT' error channel. Decode strictly first
    -- and fall back to Latin-1 (which accepts any byte sequence) so that an
    -- error message in an unexpected encoding is still reported as an 'Error'.
    decodeMessage :: ByteString -> Text
    decodeMessage bytes =
      either (const $ Encoding.decodeLatin1 bytes) id $ Encoding.decodeUtf8' bytes
-- | Execute a query with parameters and decode every returned row.
--
-- The column table for @b@ is resolved once against the result and then
-- reused for each row. The first failure (query execution, column lookup,
-- or row decoding) aborts the whole call and is returned as 'Left'.
fetch
  :: forall a b. (ToParamList a, FromRow b)
  => Connection
  -> Text
  -> a
  -> IO (Either Error [b])
fetch conn query params = runExceptT $ do
  res <- execParams conn query params
  table <- ExceptT (getColumnTable (Proxy :: Proxy b) res)
  rowCount <- liftIO (LibPQ.ntuples res)
  let decodeRow = ExceptT . fromRow res table
  traverse decodeRow [0 .. rowCount - 1]
-- | Variant of 'fetch' for queries without parameters: @()@ is passed as the
-- parameter list.
fetch_ :: forall a. FromRow a => Connection -> Text -> IO (Either Error [a])
fetch_ conn sql = fetch conn sql noParams
  where
    noParams :: ()
    noParams = ()
-- | The result columns a row decoder reads from, stored as a 'Vector' of
-- column number and type 'Oid' pairs. Values are built with
-- 'newColumnTable' and read positionally with 'indexColumnTable'.
newtype ColumnTable = ColumnTable (Vector (Column, Oid))
  deriving (Eq, Show)
-- | Build a 'ColumnTable' from a list of column/OID pairs, keeping the list
-- order.
newColumnTable :: [(Column, Oid)] -> ColumnTable
newColumnTable pairs = ColumnTable (Vector.fromList pairs)
-- | Look up the entry at the given position.
--
-- Uses 'Vector.unsafeIndex', so the position is not bounds-checked; callers
-- must only pass positions that exist in the table.
indexColumnTable :: ColumnTable -> Int -> (Column, Oid)
indexColumnTable (ColumnTable entries) position =
  Vector.unsafeIndex entries position
-- | Convert a 'ColumnTable' to a list of its column/OID pairs, in table
-- order.
toListColumnTable :: ColumnTable -> [(Column, Oid)]
toListColumnTable table =
  case table of
    ColumnTable entries -> Vector.toList entries
-- | Types that can be decoded from a single row of a query result.
--
-- Both methods have default implementations for types with a 'Generic'
-- instance whose representation satisfies the generic helper classes
-- 'GetColumnTable'' and 'FromRow''.
class FromRow a where
  -- | Resolve, against the given result, the columns needed to decode an @a@.
  getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
  default getColumnTable
    :: (Generic a, GetColumnTable' (Rep a))
    => Proxy a
    -> Result
    -> IO (Either Error ColumnTable)
  getColumnTable Proxy result =
    runExceptT (newColumnTable <$> getColumnTable' @(Rep a) Proxy result)

  -- | Decode the given row of the result using an already resolved
  -- 'ColumnTable'.
  fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
  default fromRow
    :: (Generic a, FromRow' (Rep a))
    => Result
    -> ColumnTable
    -> Row
    -> IO (Either Error a)
  fromRow result columnTable row =
    -- A fresh counter per row: each decoded field advances it by one so that
    -- fields are matched to column table entries positionally.
    newIORef 0 >>= \cursor ->
      runExceptT (fmap to (fromRow' (FromRowCtx result columnTable cursor) row))
-- | Generic helper behind the default 'getColumnTable': for a generic
-- representation @f@, produce the list of column/OID pairs found in the
-- given result, or fail with an 'Error'.
class GetColumnTable' f where
  getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [(Column, Oid)]
-- | Datatype metadata carries no columns; delegate to the wrapped
-- representation.
instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
  getColumnTable' Proxy result = getColumnTable' @f Proxy result
-- | Constructor metadata carries no columns; delegate to the wrapped
-- representation.
instance GetColumnTable' f => GetColumnTable' (M1 C c f) where
  getColumnTable' Proxy result = getColumnTable' @f Proxy result
-- | A product of fields: the columns of the left part followed by the
-- columns of the right part. The left part is resolved first, so its error
-- wins if both would fail.
instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where
  getColumnTable' Proxy result = do
    leftColumns <- getColumnTable' @f Proxy result
    rightColumns <- getColumnTable' @g Proxy result
    pure (leftColumns ++ rightColumns)
-- | Locate the result column with the given name and check that its type OID
-- is accepted by the 'FromField' instance of @a@.
--
-- Returns a singleton list with the column number and its OID. Fails with
-- 'ErrorMissingColumn' when the result has no column of that name, and with
-- 'ErrorInvalidOid' when 'validOid' rejects the column's OID.
checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [(Column, Oid)]
checkColumn Proxy nameStr result = do
  mbColumn <- liftIO (LibPQ.fnumber result columnBytes)
  column <- case mbColumn of
    Nothing -> except (Left (ErrorMissingColumn columnText))
    Just found -> pure found
  oid <- liftIO (LibPQ.ftype result column)
  except $
    if validOid @a Proxy oid
      then Right [(column, oid)]
      else Left (ErrorInvalidOid columnText oid)
  where
    columnText = Text.pack nameStr
    columnBytes = Encoding.encodeUtf8 columnText
-- | A named record field of type @t@: look up the column named after the
-- field selector and check it against @t@'s 'FromField' instance.
instance {-# OVERLAPPABLE #-}
  (FromField t, KnownSymbol nameSym)
  => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  getColumnTable' Proxy = checkColumn (Proxy :: Proxy t) fieldName
    where
      fieldName = symbolVal (Proxy :: Proxy nameSym)
-- | A named record field of type @Maybe t@: the column is checked against
-- the 'FromField' instance of the inner type @t@.
instance {-# OVERLAPPING #-}
  (KnownSymbol nameSym, FromField t)
  => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  getColumnTable' Proxy = checkColumn (Proxy :: Proxy t) fieldName
    where
      fieldName = symbolVal (Proxy :: Proxy nameSym)
-- | Everything the generic row decoder threads through a single row: the
-- query result, the resolved column table, and a mutable counter that
-- 'decodeField' reads and increments to pick the next column table entry.
data FromRowCtx = FromRowCtx Result ColumnTable (IORef Int)
-- | Generic helper behind the default 'fromRow': decode the generic
-- representation @f@ from the given row, using the state in 'FromRowCtx'.
class FromRow' f where
  fromRow' :: FromRowCtx -> Row -> ExceptT Error IO (f p)
-- | Datatype metadata: decode the wrapped representation and re-wrap it.
instance FromRow' f => FromRow' (M1 D c f) where
  fromRow' ctx row = fmap M1 (fromRow' ctx row)
-- | Constructor metadata: decode the wrapped representation and re-wrap it.
instance FromRow' f => FromRow' (M1 C c f) where
  fromRow' ctx row = fmap M1 (fromRow' ctx row)
-- | A product of fields: decode the left part, then the right part.
--
-- The order matters because each field decoder advances the shared counter
-- in 'FromRowCtx'; the applicative of 'ExceptT' runs its actions left to
-- right and stops at the first error.
instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where
  fromRow' ctx row = (:*:) <$> fromRow' ctx row <*> fromRow' ctx row
-- | Decode the next field of a row into a generic selector node.
--
-- Reads the current position from the counter in 'FromRowCtx' and advances
-- it, fetches the raw value of the corresponding column, parses it with
-- 'fromField' when it is present, and finally lets the supplied callback
-- decide what to do with the possibly absent value.
--
-- The 'Text' argument is the field name used in error positions. A parse
-- failure is reported as 'ErrorInvalidField'; any error returned by the
-- callback is passed through unchanged.
decodeField
  :: FromField t
  => Text
  -> (Row -> Maybe t -> Either Error t')
  -> FromRowCtx
  -> Row
  -> ExceptT Error IO (M1 S m (Rec0 t') p)
decodeField nameText g (FromRowCtx result columnTable iRef) row = do
  (column, oid) <- liftIO nextColumn
  rawField <- liftIO (LibPQ.getvalue result row column)
  -- A missing raw value stays 'Nothing'; a present one must parse.
  parsed <- except (traverse (parseField oid) rawField)
  M1 . K1 <$> except (g row parsed)
  where
    -- Take the column table entry at the current position and move on.
    nextColumn :: IO (Column, Oid)
    nextColumn = do
      position <- readIORef iRef
      modifyIORef' iRef (+ 1)
      pure (indexColumnTable columnTable position)

    -- Parse one raw value, tagging a failure with where it occurred.
    parseField :: FromField u => Oid -> ByteString -> Either Error u
    parseField oid bytes =
      either
        (Left . ErrorInvalidField (ErrorPosition row nameText) oid bytes)
        Right
        (fromField bytes)
-- | A named record field of type @t@: the value must be present, otherwise
-- decoding fails with 'ErrorUnexpectedNull' at this row and field name.
instance {-# OVERLAPPABLE #-}
  (FromField t, KnownSymbol nameSym)
  => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  fromRow' = decodeField fieldName requireValue
    where
      fieldName = Text.pack (symbolVal (Proxy :: Proxy nameSym))

      requireValue row = \case
        Nothing -> Left (ErrorUnexpectedNull (ErrorPosition row fieldName))
        Just value -> Right value
-- | A named record field of type @Maybe t@: an absent value is accepted and
-- becomes 'Nothing'; a present one is wrapped in 'Just'.
instance {-# OVERLAPPING #-}
  (KnownSymbol nameSym, FromField t)
  => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  fromRow' = decodeField fieldName (\_ mbValue -> Right mbValue)
    where
      fieldName = Text.pack (symbolVal (Proxy :: Proxy nameSym))