diff --git a/lib/Database/PostgreSQL/Opium.hs b/lib/Database/PostgreSQL/Opium.hs index aaf84b3..70eaec1 100644 --- a/lib/Database/PostgreSQL/Opium.hs +++ b/lib/Database/PostgreSQL/Opium.hs @@ -5,6 +5,7 @@ module Database.PostgreSQL.Opium -- * Connection Management ( Connection + , ConnectionError , connect , close -- * Queries @@ -40,7 +41,7 @@ import Database.PostgreSQL.LibPQ (Result) import qualified Data.Text.Encoding as Encoding import qualified Database.PostgreSQL.LibPQ as LibPQ -import Database.PostgreSQL.Opium.Connection (Connection, connect, close, withRawConnection) +import Database.PostgreSQL.Opium.Connection (Connection, ConnectionError, connect, close, withRawConnection) import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..)) import Database.PostgreSQL.Opium.FromField (FromField (..), RawField (..)) import Database.PostgreSQL.Opium.FromRow (FromRow (..), ColumnTable) diff --git a/lib/Database/PostgreSQL/Opium/Connection.hs b/lib/Database/PostgreSQL/Opium/Connection.hs index ff8c091..b637d27 100644 --- a/lib/Database/PostgreSQL/Opium/Connection.hs +++ b/lib/Database/PostgreSQL/Opium/Connection.hs @@ -1,8 +1,10 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} module Database.PostgreSQL.Opium.Connection ( Connection + , ConnectionError , unsafeWithRawConnection , withRawConnection , connect @@ -10,14 +12,20 @@ module Database.PostgreSQL.Opium.Connection ) where import Control.Concurrent.MVar (MVar, newMVar, modifyMVar_, withMVar) -import Data.ByteString (ByteString) +import Data.Maybe (fromMaybe) +import Data.Text (Text) +import Database.PostgreSQL.LibPQ (ConnStatus (..)) import GHC.Stack (HasCallStack) +import qualified Data.Text.Encoding as Encoding import qualified Database.PostgreSQL.LibPQ as LibPQ newtype Connection = Connection { rawConnection :: MVar (Maybe LibPQ.Connection) - } + } + +newtype ConnectionError = ConnectionError Text + deriving (Show) withRawConnection :: HasCallStack @@ -35,11 +43,18 @@ unsafeWithRawConnection f = withRawConnection $ \case Nothing -> error "raw connection is missing! perhaps the connection was already closed." Just rawConn -> f rawConn -connect :: ByteString -> IO Connection +connect :: Text -> IO (Either ConnectionError Connection) connect connectionString = do - rawConn <- LibPQ.connectdb connectionString - - Connection <$> newMVar (Just rawConn) + -- Appending the client_encoding setting overrides any previous setting in the connection string. + -- We set the client encoding here to make sure we can use it below for decoding connection + -- error messages. + rawConn <- LibPQ.connectdb $ Encoding.encodeUtf8 $ connectionString <> " client_encoding=UTF8" + status <- LibPQ.status rawConn + if status == ConnectionOk then + Right . Connection <$> newMVar (Just rawConn) + else do + rawError <- fromMaybe "" <$> LibPQ.errorMessage rawConn + pure $ Left $ ConnectionError $ Encoding.decodeUtf8Lenient rawError close :: Connection -> IO () close conn = modifyMVar_ conn.rawConnection $ \case diff --git a/opium.cabal b/opium.cabal index d7b73bd..f4f75a7 100644 --- a/opium.cabal +++ b/opium.cabal @@ -126,6 +126,7 @@ test-suite opium-test base, opium, bytestring, + containers, hspec, postgresql-libpq, time, diff --git a/test/SpecHook.hs b/test/SpecHook.hs index 507c920..1ffa593 100644 --- a/test/SpecHook.hs +++ b/test/SpecHook.hs @@ -8,7 +8,6 @@ import Test.Hspec (Spec, SpecWith, around) import Text.Printf (printf) import qualified Data.Text as Text -import qualified Data.Text.Encoding as Encoding import qualified Database.PostgreSQL.Opium as Opium @@ -20,7 +19,7 @@ setupConnection = do Just dbPort <- lookupEnv "DB_PORT" let dsn = printf "host=localhost user=%s password=%s dbname=%s port=%s" dbUser dbPass dbName dbPort - conn <- Opium.connect $ Encoding.encodeUtf8 $ Text.pack dsn + Right conn <- Opium.connect $ Text.pack dsn Right _ <- Opium.execute_ "DROP TABLE IF EXISTS person" conn Right _ <- Opium.execute_ "CREATE TABLE person (name TEXT NOT NULL, age INT NOT NULL, score DOUBLE PRECISION NOT NULL, motto TEXT)" conn