{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}

module Database.PostgreSQL.Opium.FromRow
  -- * FromRow
  ( FromRow (..)
  -- * Internal
  , toListColumnTable
  ) where

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (ExceptT (..), except, runExceptT)
import Data.ByteString (ByteString)
import Data.Bifunctor (first)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Data.Vector (Vector)
import Database.PostgreSQL.LibPQ
  ( Column
  , Oid
  , Result
  , Row
  )
import GHC.Generics (Generic, C, D, K1 (..), M1 (..), Meta (..), Rec0, Rep, S, to, (:*:) (..))
import GHC.TypeLits (KnownNat, KnownSymbol, Nat, natVal, symbolVal, type (+))

import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Data.Vector as Vector
import qualified Database.PostgreSQL.LibPQ as LibPQ

import Database.PostgreSQL.Opium.Error (Error (..), ErrorPosition (..))
import Database.PostgreSQL.Opium.FromField (FromField (..), fromField)

-- | Types that can be decoded from one row of a query 'Result'.
--
-- Decoding happens in two phases. 'getColumnTable' resolves, once per
-- 'Result', which column feeds each field. 'fromRow' then decodes
-- individual rows using that table.
--
-- Both methods have 'Generic'-based defaults:
--
-- * They support single-constructor records whose fields all have
--   selector names.
-- * Each field is matched to the result column named after its selector.
-- * Fields of type @'Maybe' t@ accept @NULL@; any other field rejects it
--   with 'ErrorUnexpectedNull'.
class FromRow a where
  -- | Look up every column this type needs in the 'Result'. Also check that
  -- each column's type 'Oid' is acceptable for the corresponding field.
  getColumnTable :: Proxy a -> Result -> IO (Either Error ColumnTable)
  default getColumnTable :: (Generic a, GetColumnTable' (Rep a)) => Proxy a -> Result -> IO (Either Error ColumnTable)
  getColumnTable Proxy result =
    runExceptT $ newColumnTable <$> getColumnTable' @(Rep a) Proxy result

  -- | Decode the given 'Row' of the 'Result'.
  --
  -- The default implementation indexes the 'ColumnTable' by field position.
  -- It therefore expects a table produced by 'getColumnTable' for this same
  -- type.
  fromRow :: Result -> ColumnTable -> Row -> IO (Either Error a)
  default fromRow :: (Generic a, FromRow' 0 (Rep a)) => Result -> ColumnTable -> Row -> IO (Either Error a)
  fromRow result columnTable =
    runExceptT . fmap to . fromRow' @0 FRProxy (FromRowCtx result columnTable)

-- | The result columns a 'FromRow' type decodes from.
--
-- Each entry pairs a column's index in the 'Result' with its type 'Oid'.
-- When built by the generic 'getColumnTable' default, there is exactly one
-- entry per record field, in field order.
newtype ColumnTable = ColumnTable (Vector (Column, Oid))
  deriving (Eq, Show)

-- | Freeze a list of resolved @(column, oid)@ pairs into a 'ColumnTable'.
-- List order is preserved.
newColumnTable :: [(Column, Oid)] -> ColumnTable
newColumnTable columns = ColumnTable (Vector.fromList columns)

-- | Fetch the @(column, oid)@ entry for the field at position @i@. Positions
-- are 0-based and counted in field order.
--
-- The lookup is bounds-checked. A position outside the table raises an
-- exception that names the index and length. This can happen with a
-- 'ColumnTable' that was not built for the type being decoded. The
-- previous 'Vector.unsafeIndex' would instead read past the end of the
-- vector, which is undefined behaviour.
indexColumnTable :: ColumnTable -> Int -> (Column, Oid)
indexColumnTable (ColumnTable v) i = v Vector.! i

-- | List the entries of a 'ColumnTable' in order.
--
-- Exported under the module's "Internal" section.
toListColumnTable :: ColumnTable -> [(Column, Oid)]
toListColumnTable (ColumnTable columns) = Vector.foldr (:) [] columns

-- | Generic traversal of a record's 'Rep'.
--
-- For every field it produces the matching result column and that column's
-- type 'Oid', in field order. Instances exist only for datatype and
-- constructor metadata, products, and selectors that carry a field name.
class GetColumnTable' f where
  getColumnTable' :: Proxy (f p) -> Result -> ExceptT Error IO [(Column, Oid)]

instance GetColumnTable' f => GetColumnTable' (M1 D c f) where
  -- Datatype metadata contributes no columns; descend into the constructor.
  getColumnTable' Proxy result = getColumnTable' @f Proxy result

instance GetColumnTable' f => GetColumnTable' (M1 C c f) where
  -- Constructor metadata contributes no columns; descend into its fields.
  getColumnTable' Proxy result = getColumnTable' @f Proxy result

instance (GetColumnTable' f, GetColumnTable' g) => GetColumnTable' (f :*: g) where
  -- Resolve the left fields first, then the right ones, so the table keeps
  -- field order. The first failing lookup aborts the traversal.
  getColumnTable' Proxy result = do
    leftColumns <- getColumnTable' @f Proxy result
    rightColumns <- getColumnTable' @g Proxy result
    pure $ leftColumns ++ rightColumns

instance {-# OVERLAPPABLE #-} (FromField t, KnownSymbol nameSym) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  -- A non-'Maybe' field: find the column named after the selector and
  -- check its 'Oid' against the field type @t@.
  getColumnTable' Proxy result = checkColumn @t Proxy selectorName result
    where
      selectorName = symbolVal @nameSym Proxy

instance {-# OVERLAPPING #-} (KnownSymbol nameSym, FromField t) => GetColumnTable' (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  -- A 'Maybe' field: the column's 'Oid' is checked against the inner type
  -- @t@, not against @Maybe t@.
  getColumnTable' Proxy result = checkColumn @t Proxy selectorName result
    where
      selectorName = symbolVal @nameSym Proxy

-- | State kept for a call to 'fromRow'.
--
-- Every generic field decoder receives this context unchanged; it is never
-- modified during decoding.
data FromRowCtx = FromRowCtx
  Result -- ^ Obtained from 'LibPQ.execParams'.
  ColumnTable -- ^ 'Vector' of expected column indices and OIDs, one per field.

-- | Value-less witness used to select a 'FromRow'' instance. It carries two
-- pieces of type-level information:
--
-- * @n@, the 'ColumnTable' position of the first field in the
--   representation fragment;
-- * @f@, the fragment itself.
data FRProxy (n :: Nat) (f :: Type -> Type) = FRProxy

-- | Generic decoder for a fragment @f@ of a record's 'Rep'. The fragment's
-- first field is stored at position @n@ of the 'ColumnTable'.
class FromRow' (n :: Nat) (f :: Type -> Type) where
  -- | Number of fields (and therefore table entries) contained in @f@.
  -- The product instance uses it to compute the starting position of its
  -- right-hand side.
  type Members f :: Nat

  -- | Decode the fields of @f@ from the given 'Row'.
  fromRow' :: FRProxy n f -> FromRowCtx -> Row -> ExceptT Error IO (f p)

instance FromRow' n f => FromRow' n (M1 D c f) where
  type Members (M1 D c f) = Members f

  -- Datatype metadata is transparent: decode the wrapped fragment at the
  -- same starting position and re-wrap it.
  fromRow' FRProxy ctx = fmap M1 . fromRow' @n FRProxy ctx

instance FromRow' n f => FromRow' n (M1 C c f) where
  type Members (M1 C c f) = Members f

  -- Constructor metadata is transparent: decode the fields at the same
  -- starting position and re-wrap them.
  fromRow' FRProxy ctx = fmap M1 . fromRow' @n FRProxy ctx

instance (FromRow' n f, FromRow' (n + Members f) g) => FromRow' n (f :*: g) where
  type Members (f :*: g) = Members f + Members g

  -- The left fragment starts at @n@. The right fragment starts right after
  -- the left one's fields. The left side is decoded first, and its failure
  -- short-circuits.
  fromRow' FRProxy ctx row = do
    leftPart <- fromRow' @n FRProxy ctx row
    rightPart <- fromRow' @(n + Members f) FRProxy ctx row
    pure (leftPart :*: rightPart)

instance {-# OVERLAPPABLE #-} (KnownNat n, KnownSymbol nameSym, FromField t) => FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) where
  type Members (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 t)) = 1

  -- A non-'Maybe' field occupies one table entry (position @n@). A missing
  -- (@NULL@) value is reported as 'ErrorUnexpectedNull'.
  fromRow' FRProxy ctx row = decodeField tablePosition fieldName rejectNull ctx row
    where
      tablePosition = fromIntegral $ natVal @n Proxy
      fieldName = Text.pack $ symbolVal @nameSym Proxy
      rejectNull r = maybe (Left $ ErrorUnexpectedNull $ ErrorPosition r fieldName) Right

instance {-# OVERLAPPING #-} (KnownNat n, KnownSymbol nameSym, FromField t) => FromRow' n (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) where
  type Members (M1 S ('MetaSel ('Just nameSym) nu ns dl) (Rec0 (Maybe t))) = 1

  -- A 'Maybe' field occupies one table entry (position @n@). A missing
  -- (@NULL@) value is kept as 'Nothing'.
  fromRow' FRProxy ctx row = decodeField tablePosition fieldName keepOptional ctx row
    where
      tablePosition = fromIntegral $ natVal @n Proxy
      fieldName = Text.pack $ symbolVal @nameSym Proxy
      keepOptional _ mbValue = Right mbValue

-- | Resolve the result column called @nameStr@ and validate its type 'Oid'
-- against the field type @a@.
--
-- On success, returns a single-entry list holding the column and its 'Oid'.
-- Fails with:
--
-- * 'ErrorMissingColumn' when no column of that name exists;
-- * 'ErrorInvalidOid' when 'validOid' rejects the column's type.
--
-- NOTE(review): the name is passed to libpq's @PQfnumber@ as-is (UTF-8
-- encoded). libpq downcases names that are not double-quoted, so a
-- mixed-case field name will not match a case-sensitive column alias.
-- Confirm this is intended.
checkColumn :: forall a. FromField a => Proxy a -> String -> Result -> ExceptT Error IO [(Column, Oid)]
checkColumn Proxy nameStr result = do
  mbColumn <- liftIO $ LibPQ.fnumber result nameBytes
  column <- except $ maybe (Left $ ErrorMissingColumn nameText) Right mbColumn
  columnOid <- liftIO $ LibPQ.ftype result column
  except $
    if validOid @a Proxy columnOid
      then Right [(column, columnOid)]
      else Left $ ErrorInvalidOid nameText columnOid
  where
    nameText = Text.pack nameStr
    nameBytes = Encoding.encodeUtf8 nameText

-- | Decode the field stored at table position @memberIndex@ for @row@, and
-- wrap the result as a generic selector value.
--
-- Steps:
--
-- 1. Fetch the raw value with 'LibPQ.getvalue'.
-- 2. If a value is present, parse it with 'fromField'. A parse failure
--    becomes 'ErrorInvalidField', carrying the row, field name, column
--    'Oid' and raw bytes. An absent value is passed on as 'Nothing'.
-- 3. Call the continuation @g@ with the row and the (possibly absent)
--    parsed value. It decides the final field value or error, for example
--    rejecting an absent value or keeping the 'Maybe'.
decodeField
  :: FromField t
  => Int
  -> Text
  -> (Row -> Maybe t -> Either Error t')
  -> FromRowCtx
  -> Row
  -> ExceptT Error IO (M1 S m (Rec0 t') p)
decodeField memberIndex nameText g (FromRowCtx result columnTable) row = do
  rawField <- liftIO $ LibPQ.getvalue result row column
  parsed <- except $ case rawField of
    Nothing -> Right Nothing
    Just bytes ->
      first (ErrorInvalidField (ErrorPosition row nameText) oid bytes) $
        Just <$> fromField bytes
  M1 . K1 <$> except (g row parsed)
  where
    (column, oid) = indexColumnTable columnTable memberIndex