{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

-- | Decode PostgreSQL query results into Haskell records via "GHC.Generics".
--
-- For a record type whose fields all have 'FromField' instances, deriving
-- 'Generic' and declaring a 'FromRow' instance without method bodies is
-- enough: each record field is looked up as a result column by its name, the
-- column's type OID is checked with 'validOid', and every row is decoded with
-- 'fromField'.
module Database.PostgreSQL.Opium
  ( ColumnTable
  , Error (..)
  , ErrorPosition (..)
  , FromField (..)
  , FromRow (..)
  , fetch_
  , toListColumnTable
  ) where

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.Bifunctor (first)
import Data.ByteString (ByteString)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Vector (Vector)
import Database.PostgreSQL.LibPQ
  ( Column
  , Connection
  , Oid
  , Result
  , Row
  )
import GHC.Generics (C, D, Generic, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownSymbol, symbolVal)

import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Data.Vector as Vector
import qualified Database.PostgreSQL.LibPQ as LibPQ

import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..))
import Database.PostgreSQL.Opium.FromField (FromField (..), fromField)

-- | Run a query without parameters, requesting results in binary format.
--
-- Fails with 'ErrorNoResult' when libpq returns no result at all, and with
-- 'ErrorInvalidResult' (carrying the result status) when the result has a
-- non-empty error message.
execParams :: Connection -> ByteString -> ExceptT Error IO Result
execParams conn query =
  liftIO (LibPQ.execParams conn query [] LibPQ.Binary) >>= \case
    Nothing -> except $ Left ErrorNoResult
    Just result -> do
      status <- liftIO $ LibPQ.resultStatus result
      mbMessage <- liftIO $ LibPQ.resultErrorMessage result
      case mbMessage of
        Just "" -> pure result
        Nothing -> pure result
        -- Server messages are sent in the connection's client encoding, which
        -- need not be UTF-8. Decode leniently so that a stray byte cannot
        -- become an exception hidden inside the returned 'Error' value.
        Just message ->
          except $ Left $ ErrorInvalidResult status $
            Encoding.decodeUtf8With lenientDecode message

-- | Run a query and decode every row of its result as an @a@.
--
-- The column table for @a@ is resolved once per result with
-- 'getColumnTable'; each row is then decoded in order with 'fromRow'.
-- Decoding stops at the first error.
fetch_
  :: forall a. FromRow a
  => Connection -> ByteString -> IO (Either Error [a])
fetch_ conn query = runExceptT $ do
  result <- execParams conn query
  columnTable <- ExceptT $ getColumnTable @a Proxy result
  nRows <- liftIO $ LibPQ.ntuples result
  mapM (ExceptT . fromRow result columnTable) [0 .. nRows - 1]

-- | For each decoded field, the result column it reads from and that
-- column's type OID, in the order the fields are decoded.
newtype ColumnTable = ColumnTable (Vector (Column, Oid))
  deriving (Eq, Show)

-- | Build a column table from a list of (column, OID) pairs.
newColumnTable :: [(Column, Oid)] -> ColumnTable
newColumnTable = ColumnTable . Vector.fromList

-- | Look up the @i@-th entry of a column table.
--
-- This is bounds-checked. A 'FromRow' instance whose 'getColumnTable' yields
-- fewer entries than its 'fromRow' consumes fails with an error call instead
-- of reading past the end of the vector.
indexColumnTable :: ColumnTable -> Int -> (Column, Oid)
indexColumnTable (ColumnTable v) i = v Vector.! i

-- | All entries of a column table, in order.
toListColumnTable :: ColumnTable -> [(Column, Oid)]
toListColumnTable (ColumnTable v) = Vector.toList v

-- | Types that can be decoded from a row of a query result.
--
-- Both methods have "GHC.Generics"-based defaults that support record types
-- whose fields are named and have 'FromField' instances.
class FromRow a where
  -- | Resolve, once per result, which column each field of @a@ reads from.
  -- The generic default fails if a column is missing or has a type OID the
  -- field does not accept.
  getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
  default getColumnTable
    :: (Generic a, GetColumnTable' (Rep a))
    => Proxy a -> Result -> IO (Either Error ColumnTable)
  getColumnTable Proxy =
    runExceptT . fmap newColumnTable . getColumnTable' @(Rep a) Proxy

  -- | Decode a single row using a column table from 'getColumnTable'.
  fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
  default fromRow
    :: (Generic a, FromRow' (Rep a))
    => Result -> ColumnTable -> Row -> IO (Either Error a)
  fromRow result columnTable row = do
    -- Fields are decoded left to right; this counter selects the column-table
    -- entry for the next field.
    iRef <- newIORef 0
    runExceptT $ to <$> fromRow' (FromRowCtx result columnTable iRef) row

-- | Generic worker for 'getColumnTable'. It walks the representation and
-- produces one column-table entry per record field, left to right.
class GetColumnTable' f where
  getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [(Column, Oid)]

instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
  getColumnTable' Proxy = getColumnTable' @f Proxy

instance GetColumnTable' f => GetColumnTable' (M1 C c f) where
  getColumnTable' Proxy = getColumnTable' @f Proxy

instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where
  getColumnTable' Proxy result =
    (++) <$> getColumnTable' @f Proxy result <*> getColumnTable' @g Proxy result

-- | Find the column named @nameStr@ in @result@ and check that its type OID
-- is accepted by 'validOid' for @a@.
--
-- Fails with 'ErrorMissingColumn' if no such column exists, or with
-- 'ErrorInvalidOid' if the OID is rejected.
--
-- NOTE(review): libpq's @PQfnumber@ treats the name like an SQL identifier.
-- It folds the name to lower case unless it is double-quoted, so a field
-- named @userId@ matches a column @userid@, not @"userId"@.
checkColumn
  :: forall a. FromField a
  => Proxy a -> String -> Result -> ExceptT Error IO [(Column, Oid)]
checkColumn Proxy nameStr result = do
  column <- ExceptT $
    maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name
  oid <- liftIO $ LibPQ.ftype result column
  if validOid @a Proxy oid
    then pure [(column, oid)]
    else except $ Left $ ErrorInvalidOid nameText oid
  where
    nameText = Text.pack nameStr
    name = Encoding.encodeUtf8 nameText

-- | A field whose type is not 'Maybe': its column is checked against @t@.
instance {-# OVERLAPPABLE #-}
  (FromField t, KnownSymbol nameSym)
  => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy

-- | A field of type @'Maybe' t@: its column is checked against @t@.
instance {-# OVERLAPPING #-}
  (KnownSymbol nameSym, FromField t)
  => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  getColumnTable' Proxy = checkColumn @t Proxy $ symbolVal @nameSym Proxy

-- | Per-row decoding state: the result, its column table, and the index of
-- the column-table entry for the next field.
data FromRowCtx = FromRowCtx Result ColumnTable (IORef Int)

-- | Generic worker for 'fromRow'.
class FromRow' f where
  fromRow' :: FromRowCtx -> Row -> ExceptT Error IO (f p)

instance FromRow' f => FromRow' (M1 D c f) where
  fromRow' ctx row = M1 <$> fromRow' ctx row

instance FromRow' f => FromRow' (M1 C c f) where
  fromRow' ctx row = M1 <$> fromRow' ctx row

instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where
  -- The left side must be decoded first. That matches the order in which
  -- 'getColumnTable'' emitted the column-table entries, because each field
  -- advances the shared counter.
  fromRow' ctx row = do
    y <- fromRow' ctx row
    z <- fromRow' ctx row
    pure $ y :*: z

-- | Decode the next field of a row.
--
-- The function takes the next column-table entry and advances the shared
-- counter. It then reads the raw value and decodes it with 'fromField' if it
-- is non-null. Finally it passes the result to @g@, which decides how a SQL
-- @NULL@ (a 'Nothing' argument) is handled.
--
-- Decoding failures are reported as 'ErrorInvalidField', tagged with the row,
-- the field name, the column OID and the raw bytes.
decodeField
  :: FromField t
  => Text
  -> (Row -> Maybe t -> Either Error t')
  -> FromRowCtx
  -> Row
  -> ExceptT Error IO (M1 S m (Rec0 t') p)
decodeField nameText g (FromRowCtx result columnTable iRef) row = do
  i <- liftIO $ readIORef iRef
  liftIO $ modifyIORef' iRef (+1)
  let (column, oid) = columnTable `indexColumnTable` i
  mbField <- liftIO $ LibPQ.getvalue result row column
  mbValue <- except $ getValue oid mbField
  value <- except $ g row mbValue
  pure $ M1 $ K1 value
  where
    -- A missing (NULL) value yields 'Nothing'; otherwise decode the bytes.
    getValue :: FromField u => LibPQ.Oid -> Maybe ByteString -> Either Error (Maybe u)
    getValue oid = maybe (Right Nothing) $ \field ->
      first (ErrorInvalidField (ErrorPosition row nameText) oid field) (Just <$> fromField field)

-- | A field whose type is not 'Maybe': a SQL @NULL@ in its column is reported
-- as 'ErrorUnexpectedNull'.
instance {-# OVERLAPPABLE #-}
  (FromField t, KnownSymbol nameSym)
  => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  fromRow' = decodeField nameText $ \row ->
    maybe (Left $ ErrorUnexpectedNull $ ErrorPosition row nameText) Right
    where
      nameText = Text.pack $ symbolVal @nameSym Proxy

-- | A field of type @'Maybe' t@: a SQL @NULL@ decodes to 'Nothing'.
instance {-# OVERLAPPING #-}
  (KnownSymbol nameSym, FromField t)
  => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  fromRow' = decodeField nameText $ const pure
    where
      nameText = Text.pack $ symbolVal @nameSym Proxy