169 lines
6.2 KiB
Haskell

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Database.PostgreSQL.Opium
( Error (..)
, ErrorPosition (..)
, FromField (..)
, FromRow (..)
, fetch_
)
where
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.ByteString (ByteString)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Database.PostgreSQL.LibPQ
( Column
, Connection
, Oid
, Result
, Row
)
import GHC.Generics (C, D, Generic, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownSymbol, symbolVal)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Database.PostgreSQL.LibPQ as LibPQ
import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..))
import Database.PostgreSQL.Opium.FromField (FromField (..), fromField)
-- | Execute @query@ on @conn@ without parameters, requesting text-format
-- results.
--
-- Fails with 'ErrorNoResult' when libpq hands back no result at all, and with
-- 'ErrorInvalidResult' (carrying the result status and the error message) when
-- the result has a non-empty error message. Otherwise yields the result.
execParams :: Connection -> ByteString -> ExceptT Error IO Result
execParams conn query =
  liftIO (LibPQ.execParams conn query [] LibPQ.Text) >>= \case
    Nothing ->
      except $ Left ErrorNoResult
    Just result -> do
      status <- liftIO $ LibPQ.resultStatus result
      mbMessage <- liftIO $ LibPQ.resultErrorMessage result
      case mbMessage of
        -- An absent or empty error message is treated as success.
        Just "" -> pure result
        Nothing -> pure result
        Just message ->
          except $ Left $ ErrorInvalidResult status $ decodeMessage message
  where
    -- 'Encoding.decodeUtf8' throws an impure exception on malformed input,
    -- which would surface lazily when the caller inspects the 'Error' instead
    -- of as a 'Left'. Decode totally: try UTF-8 first and fall back to
    -- Latin-1, which accepts every byte sequence.
    decodeMessage :: ByteString -> Text
    decodeMessage message =
      either (const $ Encoding.decodeLatin1 message) id $
        Encoding.decodeUtf8' message
-- | Run @query@ on @conn@ and decode every row of the result into an @a@.
--
-- The column table for @a@ is resolved once against the result, then each row
-- from @0@ up to @ntuples - 1@ is decoded in order. The first 'Error' raised
-- by any step is returned as 'Left'; otherwise all decoded rows are returned.
fetch_ :: forall a. FromRow a => Connection -> ByteString -> IO (Either Error [a])
fetch_ conn query = runExceptT $ do
  result <- execParams conn query
  -- TODO: Use unboxed array for columnTable
  columnTable <- ExceptT $ getColumnTable (Proxy :: Proxy a) result
  rowCount <- liftIO $ LibPQ.ntuples result
  let rows = [0 .. rowCount - 1]
  traverse (ExceptT . fromRow result columnTable) rows
-- | Pairs of result column and column type OID, one entry per record field,
-- in the order the generic traversal of the record visits the fields.
type ColumnTable = [(Column, Oid)]
-- | Fetch the entry at a zero-based position in the table.
--
-- Partial: this is plain list indexing, so an out-of-range position raises
-- the same error as '!!'.
indexColumnTable :: ColumnTable -> Int -> (Column, Oid)
indexColumnTable table position = table !! position
-- | Types that can be decoded from a single row of a query result.
--
-- Both methods have default implementations for any type with a 'Generic'
-- instance whose representation satisfies 'GetColumnTable'' and 'FromRow''.
class FromRow a where
  -- | Resolve the result columns needed to decode an @a@, or report why the
  -- result does not fit.
  getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
  default getColumnTable
    :: (Generic a, GetColumnTable' (Rep a))
    => Proxy a
    -> Result
    -> IO (Either Error ColumnTable)
  getColumnTable Proxy result =
    runExceptT $ getColumnTable' @(Rep a) Proxy result

  -- | Decode the given row of the result, using a column table previously
  -- obtained from 'getColumnTable'.
  fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
  default fromRow
    :: (Generic a, FromRow' (Rep a))
    => Result
    -> ColumnTable
    -> Row
    -> IO (Either Error a)
  fromRow result columnTable row = do
    -- Each row gets a fresh field cursor starting at the first table entry.
    cursor <- newIORef 0
    let ctx = FromRowCtx result columnTable cursor
    decoded <- runExceptT $ fromRow' ctx row
    pure $ to <$> decoded
-- | Generic worker behind 'getColumnTable': builds the column table for a
-- generic representation type @f@.
class GetColumnTable' f where
  -- | Resolve the columns for representation @f@ against the given result.
  getColumnTable'
    :: Proxy (f p)
    -> Result
    -> ExceptT Error IO ColumnTable
-- | Datatype metadata carries no columns; delegate to the wrapped
-- representation.
instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
  getColumnTable' Proxy result = getColumnTable' @f Proxy result
-- | Constructor metadata carries no columns; delegate to the wrapped
-- representation.
instance GetColumnTable' f => GetColumnTable' (M1 C c f) where
  getColumnTable' Proxy result = getColumnTable' @f Proxy result
-- | A product contributes the columns of its left half followed by the
-- columns of its right half.
instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where
  getColumnTable' Proxy result = do
    leftColumns <- getColumnTable' @f Proxy result
    rightColumns <- getColumnTable' @g Proxy result
    pure $ leftColumns ++ rightColumns
-- | Look up the column called @nameStr@ in the result and check that its type
-- OID is acceptable for field type @a@.
--
-- Yields a one-entry column table on success. Fails with
-- 'ErrorMissingColumn' when the result has no column of that name, and with
-- 'ErrorInvalidOid' when 'validOid' rejects the column's OID.
checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO ColumnTable
checkColumn Proxy nameStr result =
  liftIO (LibPQ.fnumber result columnName) >>= \case
    Nothing ->
      except $ Left $ ErrorMissingColumn columnLabel
    Just column -> do
      oid <- liftIO $ LibPQ.ftype result column
      except $
        if validOid (Proxy :: Proxy a) oid
          then Right [(column, oid)]
          else Left $ ErrorInvalidOid columnLabel oid
  where
    -- The name as 'Text' for error reporting, and UTF-8 encoded for libpq.
    columnLabel = Text.pack nameStr
    columnName = Encoding.encodeUtf8 columnLabel
-- | A named record field of type @t@: resolve the column that carries the
-- field's selector name and check it against @t@.
instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  getColumnTable' Proxy result =
    checkColumn (Proxy :: Proxy t) (symbolVal (Proxy :: Proxy nameSym)) result
-- | A named record field of type @Maybe t@: the column is checked against the
-- underlying type @t@, not against 'Maybe'.
instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  getColumnTable' Proxy result =
    checkColumn (Proxy :: Proxy t) (symbolVal (Proxy :: Proxy nameSym)) result
-- | Everything the generic row decoder threads through one row.
data FromRowCtx
  = FromRowCtx
      Result -- ^ The query result being decoded.
      ColumnTable -- ^ Column and OID for each record field.
      (IORef Int) -- ^ Position in the column table of the field decoded next.
-- | Generic worker behind 'fromRow': decodes a generic representation @f@
-- from one row.
class FromRow' f where
  -- | Decode representation @f@ from the given row using the context.
  fromRow'
    :: FromRowCtx
    -> Row
    -> ExceptT Error IO (f p)
-- | Datatype metadata: decode the wrapped representation and re-wrap it.
instance FromRow' f => FromRow' (M1 D c f) where
  fromRow' ctx row = fmap M1 (fromRow' ctx row)
-- | Constructor metadata: decode the wrapped representation and re-wrap it.
instance FromRow' f => FromRow' (M1 C c f) where
  fromRow' ctx row = fmap M1 (fromRow' ctx row)
-- | A product decodes its left half, advances the field cursor by one, then
-- decodes its right half.
instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where
  fromRow' ctx@(FromRowCtx _ _ iRef) row =
    -- Effects run left to right: left half, cursor bump, right half.
    (:*:) <$> fromRow' ctx row <* advanceCursor <*> fromRow' ctx row
    where
      advanceCursor = liftIO $ modifyIORef' iRef (+ 1)
-- | Decode the record field the cursor currently points at.
--
-- Reads the cursor (without advancing it), looks up the field's column and
-- OID in the column table, fetches the raw value from the result, and parses
-- it with 'fromField'. A NULL value skips parsing and is passed on as
-- 'Nothing'. The callback @g@ then turns the optional parsed value into the
-- final field value, or rejects it.
--
-- A parse failure is reported as 'ErrorInvalidField' together with the field's
-- position, OID and raw text.
--
-- NOTE(review): the raw bytes are decoded with 'Encoding.decodeUtf8', which
-- throws on malformed UTF-8 rather than producing an 'Error' — confirm that
-- values are always valid UTF-8 here.
decodeField
  :: FromField t
  => Text
  -> (Row -> Maybe t -> Either Error t')
  -> FromRowCtx
  -> Row
  -> ExceptT Error IO (M1 S m (Rec0 t') p)
decodeField nameText g (FromRowCtx result columnTable iRef) row = do
  cursor <- liftIO $ readIORef iRef
  let (column, oid) = indexColumnTable columnTable cursor
  rawField <- liftIO $ fetchText column
  parsed <- except $ parseField oid rawField
  M1 . K1 <$> except (g row parsed)
  where
    -- Raw text of this row's value in the given column; 'Nothing' when libpq
    -- returns no value.
    fetchText :: Column -> IO (Maybe Text)
    fetchText column = do
      mbBytes <- LibPQ.getvalue result row column
      pure $ Encoding.decodeUtf8 <$> mbBytes

    -- Parse a present value, tagging any failure with where it happened.
    parseField :: FromField u => Oid -> Maybe Text -> Either Error (Maybe u)
    parseField _ Nothing = Right Nothing
    parseField oid (Just fieldText) =
      either
        (Left . ErrorInvalidField (ErrorPosition row nameText) oid fieldText)
        (Right . Just)
        (fromField fieldText)
-- | A named record field of type @t@: a NULL value is rejected with
-- 'ErrorUnexpectedNull' at the field's position.
instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  fromRow' = decodeField fieldName requireValue
    where
      fieldName = Text.pack $ symbolVal (Proxy :: Proxy nameSym)
      requireValue row = \case
        Nothing -> Left $ ErrorUnexpectedNull $ ErrorPosition row fieldName
        Just value -> Right value
-- | A named record field of type @Maybe t@: a NULL value is accepted and
-- decoded as 'Nothing'.
instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  fromRow' = decodeField fieldName keepOptional
    where
      fieldName = Text.pack $ symbolVal (Proxy :: Proxy nameSym)
      -- The optional value is passed through unchanged; the row is not needed.
      keepOptional _row mbValue = Right mbValue