{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

module Database.PostgreSQL.Opium
  ( Error (..)
  , FieldError (..)
  , FromField (..)
  , FromRow (..)
  , fetch_
  )
  where

import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.ByteString (ByteString)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Data.Text.Encoding.Error (lenientDecode)
import Database.PostgreSQL.LibPQ
  ( Column
  , Connection
  , Result
  , Row
  )
import GHC.Generics (C, D, Generic, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownSymbol, symbolVal)

import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Database.PostgreSQL.LibPQ as LibPQ

import Database.PostgreSQL.Opium.Error (Error (..))
import Database.PostgreSQL.Opium.FromField (FieldError (..), FromField (..))

-- | Run a query that takes no parameters, requesting text-format results.
--
-- Returns 'ErrorNoResult' when libpq produced no result at all. A result
-- counts as failed when it carries a non-empty error message. In that case
-- 'ErrorInvalidResult' is returned with the result status and the message.
execParams :: Connection -> ByteString -> IO (Either Error Result)
execParams conn query =
  LibPQ.execParams conn query [] LibPQ.Text >>= \case
    Nothing ->
      pure $ Left ErrorNoResult
    Just result -> do
      status <- LibPQ.resultStatus result
      mbMessage <- LibPQ.resultErrorMessage result
      pure $ case mbMessage of
        Just "" -> Right result
        -- The message is in the connection's client encoding and may quote
        -- user data, so it is not guaranteed to be UTF-8. Decode leniently
        -- so the error is always reported as a 'Left' rather than an
        -- exception raised when the 'Text' is later forced.
        Just message ->
          Left $ ErrorInvalidResult status $ Encoding.decodeUtf8With lenientDecode message
        Nothing -> Right result

-- | Execute a parameterless query and decode every row of its result with
-- 'fromRow'.
--
-- Query execution is attempted first. If it fails, its error is returned
-- and no rows are decoded. Otherwise the first row-decoding error (if any)
-- is returned.
fetch_ :: FromRow a => Connection -> ByteString -> IO (Either Error [a])
fetch_ conn query =
  execParams conn query >>= either (pure . Left) fetchResult

-- | Decode all rows of a result, in row order, with 'fromRow'.
--
-- Decoding stops at the first row that yields a 'Left', and that error is
-- returned.
fetchResult :: FromRow a => Result -> IO (Either Error [a])
fetchResult result =
  LibPQ.ntuples result >>= \rowCount ->
    runExceptT $ traverse decodeRow [0 .. rowCount - 1]
  where
    decodeRow row = ExceptT (fromRow row result)

-- | Types that can be decoded from a single row of a query 'Result'.
--
-- For types with a 'Generic' instance, the default method decodes the
-- generic representation via 'FromRow'' and converts it back with 'to'.
class FromRow a where
  fromRow :: Row -> Result -> IO (Either Error a)
  default fromRow :: (Generic a, FromRow' (Rep a)) => Row -> Result -> IO (Either Error a)
  fromRow row = fmap (fmap to) . fromRow' row

-- | Generic-representation counterpart of 'FromRow'. Decodes a 'Rep'
-- structure (metadata wrappers, products, record selectors) from one row
-- of a 'Result'. @p@ is the representation's type parameter and is never
-- inspected.
class FromRow' f where
  fromRow' :: Row -> Result -> IO (Either Error (f p))

-- | Datatype metadata: decode the wrapped representation and re-wrap it.
instance FromRow' f => FromRow' (M1 D c f) where
  fromRow' row = fmap (fmap M1) . fromRow' row

-- | Constructor metadata: decode the wrapped fields and re-wrap them.
instance FromRow' f => FromRow' (M1 C c f) where
  fromRow' row = fmap (fmap M1) . fromRow' row

-- | Products of fields: decode the left half, then the right half, from the
-- same row.
--
-- Decoding short-circuits through 'ExceptT'. Once the left half has
-- produced an error, the right half is not read or decoded at all. The
-- error returned is the same one the non-short-circuiting version would
-- report: the first failure, left to right.
instance (FromRow' f, FromRow' g) => FromRow' (f :*: g) where
  fromRow' row result = runExceptT $
    (:*:) <$> ExceptT (fromRow' row result) <*> ExceptT (fromRow' row result)

-- | Decode the column called @nameText@ from @row@ of @result@ into the
-- generic representation of a record selector.
--
-- The steps are:
--
-- 1. Locate the column by name.
-- 2. Read its type OID and its raw value.
-- 3. Parse a non-NULL value with 'fromField'.
-- 4. Pass the optional parsed value to @g@, which decides how NULL is
--    treated (rejected, or kept as 'Nothing').
--
-- Fails with 'ErrorMissingColumn' when no column matches the name, or with
-- 'ErrorDecode' when 'fromField' rejects the value. Otherwise returns
-- whatever 'Left' @g@ produces, if any.
decodeField
  :: FromField t
  => Text
  -> (Row -> Maybe t -> Either Error t')
  -> Row
  -> Result
  -> IO (Either Error (M1 S m (Rec0 t') p))
decodeField nameText g row result = runExceptT $ do
  column <- getColumn
  oid <- ExceptT $ pure <$> LibPQ.ftype result column
  mbField <- getFieldText column
  mbValue <- getValue oid mbField
  value <- except $ g row mbValue
  pure $ M1 $ K1 value
  where
    name = Encoding.encodeUtf8 nameText

    -- libpq's PQfnumber parses its argument like an SQL identifier and
    -- downcases it unless it is double-quoted. An unquoted lookup of
    -- "userId" can therefore only ever match a column named "userid".
    -- The quoted form (embedded quotes doubled) matches the exact spelling.
    quotedName =
      Encoding.encodeUtf8 $ "\"" <> Text.replace "\"" "\"\"" nameText <> "\""

    -- Try an exact-spelling match first, then fall back to the unquoted
    -- (case-folded) lookup so previously matching columns still match.
    getColumn :: ExceptT Error IO Column
    getColumn = ExceptT $ do
      exact <- LibPQ.fnumber result quotedName
      found <- maybe (LibPQ.fnumber result name) (pure . Just) exact
      pure $ maybe (Left $ ErrorMissingColumn nameText) Right found

    -- NOTE(review): 'Encoding.decodeUtf8' throws on invalid UTF-8 instead
    -- of producing a 'Left'. This assumes the connection's client encoding
    -- is UTF-8; confirm.
    getFieldText :: Column -> ExceptT Error IO (Maybe Text)
    getFieldText column =
      ExceptT $ Right . fmap Encoding.decodeUtf8 <$> LibPQ.getvalue result row column

    -- A NULL field ('Nothing') is passed through without calling 'fromField'.
    getValue :: FromField u => LibPQ.Oid -> Maybe Text -> ExceptT Error IO (Maybe u)
    getValue oid = except . maybe
      (Right Nothing)
      (fmap Just . mapLeft (ErrorDecode row nameText) . fromField oid)

-- | A named record field of non-'Maybe' type. It is decoded from the column
-- with the field's name, and a NULL value is reported as
-- 'ErrorUnexpectedNull'.
instance {-# OVERLAPPABLE #-}
  (FromField t, KnownSymbol nameSym) =>
  FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t))
  where
    fromRow' = decodeField fieldName rejectNull
      where
        fieldName = Text.pack (symbolVal (Proxy :: Proxy nameSym))

        rejectNull row = \case
          Nothing -> Left (ErrorUnexpectedNull row fieldName)
          Just value -> Right value

-- | A named record field of type @'Maybe' t@. It is decoded from the column
-- with the field's name, and a NULL value becomes 'Nothing'.
instance {-# OVERLAPPING #-}
  (KnownSymbol nameSym, FromField t) =>
  FromRow' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t)))
  where
    fromRow' = decodeField fieldName (\_row mbValue -> Right mbValue)
      where
        fieldName = Text.pack (symbolVal (Proxy :: Proxy nameSym))

-- | Transform the 'Left' side of an 'Either' and pass a 'Right' through
-- unchanged (equivalent to 'Data.Bifunctor.first' on 'Either').
mapLeft :: (b -> c) -> Either b a -> Either c a
mapLeft f = either (Left . f) Right