diff --git a/lib/Database/PostgreSQL/Opium.hs b/lib/Database/PostgreSQL/Opium.hs index de81a1b..6ab8142 100644 --- a/lib/Database/PostgreSQL/Opium.hs +++ b/lib/Database/PostgreSQL/Opium.hs @@ -16,6 +16,7 @@ module Database.PostgreSQL.Opium , FromField (..) , FromRow (..) , RawField (..) + , fetch , fetch_ , toListColumnTable ) @@ -45,11 +46,17 @@ import qualified Database.PostgreSQL.LibPQ as LibPQ import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..)) import Database.PostgreSQL.Opium.FromField (FromField (..), fromField, RawField (..)) +import Database.PostgreSQL.Opium.ToParamList (ToParamList (..)) -execParams :: Connection -> Text -> ExceptT Error IO Result -execParams conn query = do +execParams + :: ToParamList a + => Connection + -> Text + -> a + -> ExceptT Error IO Result +execParams conn query params = do let queryBytes = Encoding.encodeUtf8 query - liftIO (LibPQ.execParams conn queryBytes [] LibPQ.Binary) >>= \case + liftIO (LibPQ.execParams conn queryBytes (toParamList params) LibPQ.Binary) >>= \case Nothing -> except $ Left ErrorNoResult Just result -> do @@ -60,13 +67,21 @@ execParams conn query = do Nothing -> pure result Just message -> except $ Left $ ErrorInvalidResult status $ Encoding.decodeUtf8 message -fetch_ :: forall a. FromRow a => Connection -> Text -> IO (Either Error [a]) -fetch_ conn query = runExceptT $ do - result <- execParams conn query - columnTable <- ExceptT $ getColumnTable @a Proxy result +fetch + :: forall a b. (ToParamList a, FromRow b) + => Connection + -> Text + -> a + -> IO (Either Error [b]) +fetch conn query params = runExceptT $ do + result <- execParams conn query params + columnTable <- ExceptT $ getColumnTable @b Proxy result nRows <- liftIO $ LibPQ.ntuples result mapM (ExceptT . fromRow result columnTable) [0..nRows - 1] +fetch_ :: forall a. FromRow a => Connection -> Text -> IO (Either Error [a]) +fetch_ conn query = fetch conn query () + newtype ColumnTable = ColumnTable (Vector (Column, Oid)) deriving (Eq, Show) diff --git a/lib/Database/PostgreSQL/Opium/ToField.hs b/lib/Database/PostgreSQL/Opium/ToField.hs new file mode 100644 index 0000000..0697e0c --- /dev/null +++ b/lib/Database/PostgreSQL/Opium/ToField.hs @@ -0,0 +1,50 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeApplications #-} + +module Database.PostgreSQL.Opium.ToField + ( ToField (..) + ) where + +import Data.Bits (Bits (..)) +import Data.ByteString (ByteString) +import Data.List (singleton) +import Data.Text (Text) +import Data.Word (Word32) +import Database.PostgreSQL.LibPQ (Format (..), Oid) +import Unsafe.Coerce (unsafeCoerce) + +import qualified Data.ByteString as BS +import qualified Data.Text as Text +import qualified Data.Text.Encoding as Encoding +import qualified Database.PostgreSQL.Opium.Oid as Oid + +class ToField a where + toField :: a -> Maybe (Oid, ByteString, Format) + +instance ToField ByteString where + toField x = Just (Oid.bytea, x, Binary) + +instance ToField Text where + toField x = Just (Oid.text, Encoding.encodeUtf8 x, Binary) + +instance ToField String where + toField = toField . Text.pack + +instance ToField Char where + toField = toField . singleton + +-- Potentially slow, but good enough for now +encodeBigEndian :: (Integral a, Bits a) => Int -> a -> ByteString +encodeBigEndian n = BS.pack . go [] n + where + go acc 0 _ = acc + go acc i x = go (fromIntegral (x .&. 0xff) : acc) (i - 1) (x `shiftR` 8) + +instance ToField Int where + toField x = Just (Oid.bigint, encodeBigEndian 8 x, Binary) + +instance ToField Float where + toField x = Just (Oid.real, encodeBigEndian @Word32 4 $ unsafeCoerce x, Binary) + +instance ToField Double where + toField x = Just (Oid.doublePrecision, encodeBigEndian @Word 8 $ unsafeCoerce x, Binary) diff --git a/lib/Database/PostgreSQL/Opium/ToParamList.hs b/lib/Database/PostgreSQL/Opium/ToParamList.hs new file mode 100644 index 0000000..1120cd0 --- /dev/null +++ b/lib/Database/PostgreSQL/Opium/ToParamList.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeOperators #-} + +module Database.PostgreSQL.Opium.ToParamList + ( ToParamList (..) + ) where + +import Data.ByteString (ByteString) +import Data.Functor.Identity (Identity) +import Database.PostgreSQL.LibPQ (Format, Oid) +import GHC.Generics (Generic, K1 (..), M1 (..), Rec0, Rep, U1 (..), from, (:*:) (..)) + +import Database.PostgreSQL.Opium.ToField (ToField (..)) + +class ToParamList a where + toParamList :: a -> [Maybe (Oid, ByteString, Format)] + default toParamList :: (Generic a, ToParamList' (Rep a)) => a -> [Maybe (Oid, ByteString, Format)] + toParamList = toParamList' . from + +instance ToField a => ToParamList [a] where + toParamList = map toField + +instance ToParamList () where + +instance ToField a => ToParamList (Identity a) where + +instance (ToField a, ToField b) => ToParamList (a, b) where + +instance (ToField a, ToField b, ToField c) => ToParamList (a, b, c) where + +instance (ToField a, ToField b, ToField c, ToField d) => ToParamList (a, b, c, d) where + +instance (ToField a, ToField b, ToField c, ToField d, ToField e) => ToParamList (a, b, c, d, e) where + +instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f) => ToParamList (a, b, c, d, e, f) where + +instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, ToField g) => ToParamList (a, b, c, d, e, f, g) where + +instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, ToField g, ToField h) => ToParamList (a, b, c, d, e, f, g, h) where + +instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, ToField g, ToField h, ToField i) => ToParamList (a, b, c, d, e, f, g, h, i) where + +instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, ToField g, ToField h, ToField i, ToField j) => ToParamList (a, b, c, d, e, f, g, h, i, j) where + +instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, ToField g, ToField h, ToField i, ToField j, ToField k) => ToParamList (a, b, c, d, e, f, g, h, i, j, k) where + +class ToParamList' f where + toParamList' :: f p -> [Maybe (Oid, ByteString, Format)] + +instance ToField t => ToParamList' (Rec0 t) where + toParamList' (K1 x) = [toField x] + +instance ToParamList' f => ToParamList' (M1 t c f) where + toParamList' (M1 x) = toParamList' x + +instance ToParamList' U1 where + toParamList' U1 = [] + +instance (ToParamList' f, ToParamList' g) => ToParamList' (f :*: g) where + toParamList' (x :*: y) = toParamList' x ++ toParamList' y diff --git a/test/Database/PostgreSQL/OpiumSpec.hs b/test/Database/PostgreSQL/OpiumSpec.hs index a1ae811..ccfdfff 100644 --- a/test/Database/PostgreSQL/OpiumSpec.hs +++ b/test/Database/PostgreSQL/OpiumSpec.hs @@ -7,6 +7,7 @@ module Database.PostgreSQL.OpiumSpec (spec) where import Data.ByteString (ByteString) +import Data.Functor.Identity (Identity (..)) import Data.Proxy (Proxy (..)) import Data.Text (Text) import Database.PostgreSQL.LibPQ (Connection) @@ -46,6 +47,12 @@ data ScoreByAge = ScoreByAge instance Opium.FromRow ScoreByAge where +data Only a = Only + { only :: a + } deriving (Eq, Generic, Show) + +instance Opium.FromField a => Opium.FromRow (Only a) where + isLeft :: Either a b -> Bool isLeft (Left _) = True isLeft _ = False @@ -117,6 +124,15 @@ spec = do row <- Opium.fromRow result columnTable 0 row `shouldBe` Right (ManyFields "abc" 42 1.0 "test" True) + describe "fetch" $ do + it "Passes numbered parameters and retrieves a list of rows" $ \conn -> do + rows <- Opium.fetch conn "SELECT ($1 + $2) AS only" (17 :: Int, 25 :: Int) + rows `shouldBe` Right [Only (42 :: Int)] + + it "Uses Identity to pass single parameters" $ \conn -> do + rows <- Opium.fetch conn "SELECT count(*) AS only FROM person WHERE name = $1" $ Identity ("paul" :: Text) + rows `shouldBe` Right [Only (1 :: Int)] + describe "fetch_" $ do it "Retrieves a list of rows" $ \conn -> do rows <- Opium.fetch_ conn "SELECT * FROM person"