diff --git a/lib/Database/PostgreSQL/Opium.hs b/lib/Database/PostgreSQL/Opium.hs index c77c514..790ca83 100644 --- a/lib/Database/PostgreSQL/Opium.hs +++ b/lib/Database/PostgreSQL/Opium.hs @@ -2,51 +2,71 @@ {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} module Database.PostgreSQL.Opium - ( FromField (..) + ( Error (..) + , FieldError (..) + , FromField (..) , FromRow (..) , fetch_ ) where -import Control.Monad.Trans.Maybe (MaybeT (..)) +import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT) import Data.ByteString (ByteString) -import Data.Proxy (Proxy (Proxy)) +import Data.Proxy (Proxy (..)) +import Data.Text (Text) import Database.PostgreSQL.LibPQ - ( Connection + ( Column + , Connection , Result , Row ) import GHC.Generics (C, D, Generic, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..)) -import GHC.TypeLits (KnownSymbol, symbolVal) -import Text.Printf (printf) +import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) import qualified Data.Text as Text import qualified Data.Text.Encoding as Encoding import qualified Database.PostgreSQL.LibPQ as LibPQ -import Database.PostgreSQL.Opium.FromField (FromField (..)) +import Database.PostgreSQL.Opium.Error (Error (..)) +import Database.PostgreSQL.Opium.FromField (FieldError (..), FromField (..)) -fetch_ :: FromRow a => Connection -> ByteString -> IO (Maybe [a]) -fetch_ conn query = runMaybeT $ do - result <- MaybeT $ LibPQ.execParams conn query [] LibPQ.Text - MaybeT $ fetchResult result +execParams :: Connection -> ByteString -> IO (Either Error Result) +execParams conn query = do + LibPQ.execParams conn query [] LibPQ.Text >>= \case + Nothing -> + pure $ Left ErrorNoResult + Just result -> do + status <- LibPQ.resultStatus result + mbMessage <- LibPQ.resultErrorMessage result + case mbMessage of + Just "" -> pure $ Right result + Just message -> pure $ Left $ ErrorInvalidResult status $ Encoding.decodeUtf8 message + Nothing -> pure $ Right result -fetchResult :: FromRow a => Result -> IO (Maybe [a]) +fetch_ :: FromRow a => Connection -> ByteString -> IO (Either Error [a]) +fetch_ conn query = runExceptT $ do + result <- ExceptT $ execParams conn query + ExceptT $ fetchResult result + +fetchResult :: FromRow a => Result -> IO (Either Error [a]) fetchResult result = do nRows <- LibPQ.ntuples result - runMaybeT $ mapM (MaybeT . flip fromRow result) [0..nRows - 1] + runExceptT $ mapM (ExceptT . flip fromRow result) [0..nRows - 1] class FromRow a where - fromRow :: Row -> Result -> IO (Maybe a) - default fromRow :: (Generic a, FromRow' (Rep a)) => Row -> Result -> IO (Maybe a) + fromRow :: Row -> Result -> IO (Either Error a) + default fromRow :: (Generic a, FromRow' (Rep a)) => Row -> Result -> IO (Either Error a) fromRow row result = fmap to <$> fromRow' row result class FromRow' f where - fromRow' :: Row -> Result -> IO (Maybe (f p)) + fromRow' :: Row -> Result -> IO (Either Error (f p)) instance FromRow' f => FromRow' (M1 D c f) where fromRow' row result = fmap M1 <$> fromRow' row result @@ -60,20 +80,48 @@ instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where z <- fromRow' row result pure $ (:*:) <$> y <*> z -instance (KnownSymbol nameSym, FromField t) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where - fromRow' row result = do - mbColumn <- LibPQ.fnumber result name - case mbColumn of - Nothing -> pure Nothing - Just column -> do - mbField <- LibPQ.getvalue result row column - ty <- LibPQ.ftype result column - case fromField ty . Encoding.decodeUtf8 =<< mbField of - Nothing -> do - format <- LibPQ.fformat result column - printf "field %s: %s (oid: %s, format: %s)\n" (show name) (show mbField) (show ty) (show format) - pure Nothing - Just value -> - pure $ Just $ M1 $ K1 value +-- TODO: Can we clean this up? + +decodeField + :: FromField t + => Text + -> (Row -> Maybe t -> Either Error t') + -> Row + -> Result + -> IO (Either Error (M1 S ('MetaSel ('Just (nameSym :: Symbol)) nu ns dl) (Rec0 t') p)) +decodeField nameText g row result = runExceptT $ do + column <- getColumn + oid <- ExceptT $ Right <$> LibPQ.ftype result column + mbField <- getValue column + value <- case mbField of + Nothing -> + except $ g row Nothing + Just field -> do + value <- except $ mapLeft (ErrorDecode row nameText) $ fromField oid $ Encoding.decodeUtf8 field + except $ g row $ Just value + pure $ M1 $ K1 value where - name = Encoding.encodeUtf8 $ Text.pack $ symbolVal (Proxy :: Proxy nameSym) + name = Encoding.encodeUtf8 nameText + + getColumn :: ExceptT Error IO Column + getColumn = ExceptT $ + maybe (Left $ ErrorMissingColumn nameText) Right <$> LibPQ.fnumber result name + + getValue :: Column -> ExceptT Error IO (Maybe ByteString) + getValue column = ExceptT $ Right <$> LibPQ.getvalue result row column + +instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where + fromRow' = decodeField nameText $ \row -> \case + Nothing -> Left $ ErrorUnexpectedNull row nameText + Just value -> Right value + where + nameText = Text.pack $ symbolVal (Proxy :: Proxy nameSym) + +instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where + fromRow' = decodeField nameText $ const Right + where + nameText = Text.pack $ symbolVal (Proxy :: Proxy nameSym) + +mapLeft :: (b -> c) -> Either b a -> Either c a +mapLeft f (Left l) = Left $ f l +mapLeft _ (Right r) = Right r diff --git a/lib/Database/PostgreSQL/Opium/Error.hs b/lib/Database/PostgreSQL/Opium/Error.hs new file mode 100644 index 0000000..e5c7f07 --- /dev/null +++ b/lib/Database/PostgreSQL/Opium/Error.hs @@ -0,0 +1,14 @@ +module Database.PostgreSQL.Opium.Error (Error (..)) where + +import Data.Text (Text) +import Database.PostgreSQL.LibPQ (ExecStatus, Row) + +import Database.PostgreSQL.Opium.FromField (FieldError) + +data Error + = ErrorDecode Row Text FieldError + | ErrorNoResult + | ErrorInvalidResult ExecStatus Text + | ErrorMissingColumn Text + | ErrorUnexpectedNull Row Text + deriving (Eq, Show) diff --git a/lib/Database/PostgreSQL/Opium/FromField.hs b/lib/Database/PostgreSQL/Opium/FromField.hs index 9dd61e4..c06731c 100644 --- a/lib/Database/PostgreSQL/Opium/FromField.hs +++ b/lib/Database/PostgreSQL/Opium/FromField.hs @@ -1,13 +1,17 @@ -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE UndecidableInstances #-} -module Database.PostgreSQL.Opium.FromField (FromField (..)) where +module Database.PostgreSQL.Opium.FromField + ( FieldError (..) + , FromField (..) + ) where import Data.Attoparsec.Text ( Parser , decimal , parseOnly , signed + , takeText ) import Data.Text (Text) import Database.PostgreSQL.LibPQ (Oid) @@ -18,23 +22,30 @@ import qualified Database.PostgreSQL.Opium.Oid as Oid (\/) :: (a -> Bool) -> (a -> Bool) -> a -> Bool p \/ q = \x -> p x || q x -eitherToMaybe :: Either b a -> Maybe a -eitherToMaybe = \case - Left _ -> Nothing - Right x -> Just x +data FieldError + = FieldErrorUnexpectedNull + | FieldErrorInvalidOid Oid + | FieldErrorInvalidField Oid Text String + deriving (Eq, Show) + +mapLeft :: (b -> c) -> Either b a -> Either c a +mapLeft f (Left l) = Left $ f l +mapLeft _ (Right r) = Right r fromParser :: (Oid -> Bool) -> Parser a -> Oid -> Text - -> Maybe a -fromParser validOid parser oid value - | validOid oid = eitherToMaybe $ parseOnly parser value - | otherwise = Nothing + -> Either FieldError a +fromParser validOid parser oid field + | validOid oid = + mapLeft (FieldErrorInvalidField oid field) $ parseOnly parser field + | otherwise = + Left $ FieldErrorInvalidOid oid class FromField a where - fromField :: Oid -> Text -> Maybe a + fromField :: Oid -> Text -> Either FieldError a instance FromField Int where fromField = fromParser @@ -42,11 +53,9 @@ instance FromField Int where (signed decimal) instance FromField Text where - fromField oid text = - if Oid.text oid || Oid.character oid || Oid.characterVarying oid then - Just text - else - Nothing + fromField = fromParser + (Oid.text \/ Oid.character \/ Oid.characterVarying) + takeText instance FromField String where fromField oid text = Text.unpack <$> fromField oid text diff --git a/opium.cabal b/opium.cabal index dda2388..4b97eb6 100644 --- a/opium.cabal +++ b/opium.cabal @@ -64,6 +64,7 @@ library -- Modules included in this library but not exported. other-modules: + Database.PostgreSQL.Opium.Error, Database.PostgreSQL.Opium.FromField, Database.PostgreSQL.Opium.Oid diff --git a/test/Database/PostgreSQL/Opium/FromFieldSpec.hs b/test/Database/PostgreSQL/Opium/FromFieldSpec.hs index 7a31547..c3d4d97 100644 --- a/test/Database/PostgreSQL/Opium/FromFieldSpec.hs +++ b/test/Database/PostgreSQL/Opium/FromFieldSpec.hs @@ -32,8 +32,8 @@ instance FromRow SingleString where shouldFetch :: (Eq a, FromRow a, Show a) => Connection -> ByteString -> [a] -> IO () shouldFetch conn query expectedRows = do - Just actualRows <- Opium.fetch_ conn query - actualRows `shouldBe` expectedRows + actualRows <- Opium.fetch_ conn query + actualRows `shouldBe` Right expectedRows spec :: SpecWith Connection spec = do diff --git a/test/Database/PostgreSQL/OpiumSpec.hs b/test/Database/PostgreSQL/OpiumSpec.hs index 4f779b7..f09eeb3 100644 --- a/test/Database/PostgreSQL/OpiumSpec.hs +++ b/test/Database/PostgreSQL/OpiumSpec.hs @@ -7,7 +7,7 @@ module Database.PostgreSQL.OpiumSpec (spec) where import Data.Text (Text) import Database.PostgreSQL.LibPQ (Connection) import GHC.Generics (Generic) -import Test.Hspec (SpecWith, describe, it, shouldBe) +import Test.Hspec (SpecWith, describe, it, shouldBe, shouldSatisfy) import qualified Database.PostgreSQL.LibPQ as LibPQ import qualified Database.PostgreSQL.Opium as Opium @@ -19,6 +19,16 @@ data Person = Person instance Opium.FromRow Person where +newtype MaybeTest = MaybeTest + { a :: Maybe String + } deriving (Eq, Generic, Show) + +instance Opium.FromRow MaybeTest where + +isLeft :: Either a b -> Bool +isLeft (Left _) = True +isLeft _ = False + spec :: SpecWith Connection spec = do describe "fromRow" $ do @@ -26,12 +36,40 @@ spec = do Just result <- LibPQ.execParams conn "SELECT * FROM person" [] LibPQ.Text row0 <- Opium.fromRow @Person (LibPQ.Row 0) result - row0 `shouldBe` Just (Person "paul" 25) + row0 `shouldBe` Right (Person "paul" 25) row1 <- Opium.fromRow @Person (LibPQ.Row 1) result - row1 `shouldBe` Just (Person "albus" 103) + row1 `shouldBe` Right (Person "albus" 103) + + it "decodes NULL into Nothing for Maybes" $ \conn -> do + Just result <- LibPQ.execParams conn "SELECT NULL AS a" [] LibPQ.Text + + row <- Opium.fromRow (LibPQ.Row 0) result + row `shouldBe` Right (MaybeTest Nothing) + + it "decodes values into Just for Maybes" $ \conn -> do + Just result <- LibPQ.execParams conn "SELECT 'abc' AS a" [] LibPQ.Text + + row <- Opium.fromRow (LibPQ.Row 0) result + row `shouldBe` Right (MaybeTest $ Just "abc") describe "fetch_" $ do it "retrieves a list of rows" $ \conn -> do rows <- Opium.fetch_ conn "SELECT * FROM person" - rows `shouldBe` Just [Person "paul" 25, Person "albus" 103] + rows `shouldBe` Right [Person "paul" 25, Person "albus" 103] + + it "fails for invalid queries" $ \conn -> do + rows <- Opium.fetch_ @Person conn "MRTLBRNFT" + rows `shouldSatisfy` isLeft + + it "fails for missing columns" $ \conn -> do + rows <- Opium.fetch_ @Person conn "SELECT name FROM person" + rows `shouldBe` Left (Opium.ErrorMissingColumn "age") + + it "fails for unexpected NULLs" $ \conn -> do + rows <- Opium.fetch_ @Person conn "SELECT NULL AS name, 0 AS age" + rows `shouldBe` Left (Opium.ErrorUnexpectedNull (LibPQ.Row 0) "name") + + it "fails for the wrong column type" $ \conn -> do + rows <- Opium.fetch_ @Person conn "SELECT 'quby' AS name, 'indeterminate' AS age" + rows `shouldBe` Left (Opium.ErrorDecode (LibPQ.Row 0) "age" $ Opium.FieldErrorInvalidOid $ LibPQ.Oid 25) diff --git a/test/SpecHook.hs b/test/SpecHook.hs index d02b716..07fa76c 100644 --- a/test/SpecHook.hs +++ b/test/SpecHook.hs @@ -23,7 +23,7 @@ setupConnection = do conn <- LibPQ.connectdb $ Encoding.encodeUtf8 $ Text.pack dsn _ <- LibPQ.setClientEncoding conn "UTF8" - _ <- LibPQ.exec conn "CREATE TABLE person (name TEXT NOT NULL, age INT NOT NULL)" + _ <- LibPQ.exec conn "CREATE TABLE person (name TEXT NOT NULL, age INT NOT NULL, motto TEXT)" _ <- LibPQ.exec conn "INSERT INTO person VALUES ('paul', 25), ('albus', 103)" pure conn