{-# LANGUAGE OverloadedStrings #-}

module SpecHook (hook) where

import Control.Exception (bracket)
import Database.PostgreSQL.LibPQ (Connection)
import System.Environment (lookupEnv)
import Test.Hspec (Spec, SpecWith, around)
import Text.Printf (printf)

import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Database.PostgreSQL.LibPQ as LibPQ

setupConnection :: IO Connection
-- | Open a connection to the test database and create the @person@ fixture
-- table used by the specs.
--
-- Connection parameters are read from the @DB_USER@, @DB_PASS@, @DB_NAME@
-- and @DB_PORT@ environment variables; the host is always @localhost@.
--
-- Throws an 'IOError' (via 'fail') naming the missing variable when one is
-- unset, or carrying libpq's error message when the connection cannot be
-- established. On a failed connection the handle is released before
-- throwing, because 'bracket' never runs its release action when the
-- acquire action itself fails.
--
-- NOTE(review): the results of the CREATE/INSERT statements are still
-- ignored; a failure there (e.g. a leftover @person@ table) goes unnoticed.
setupConnection = do
  dbUser <- requireEnv "DB_USER"
  dbPass <- requireEnv "DB_PASS"
  dbName <- requireEnv "DB_NAME"
  dbPort <- requireEnv "DB_PORT"

  -- Every value is quoted so that empty values (e.g. an empty DB_PASS) or
  -- values containing spaces cannot swallow the following keyword.
  let dsn :: String
      dsn =
        printf
          "host=localhost user=%s password=%s dbname=%s port=%s"
          (quoteConnValue dbUser)
          (quoteConnValue dbPass)
          (quoteConnValue dbName)
          (quoteConnValue dbPort)
  conn <- LibPQ.connectdb $ Encoding.encodeUtf8 $ Text.pack dsn

  -- connectdb does not throw on failure; it returns a handle whose status
  -- must be inspected explicitly.
  connStatus <- LibPQ.status conn
  case connStatus of
    LibPQ.ConnectionOk -> pure ()
    _ -> do
      err <- LibPQ.errorMessage conn
      LibPQ.finish conn
      fail $
        "setupConnection: could not connect to the test database: "
          ++ maybe "no error message" show err

  _ <- LibPQ.setClientEncoding conn "UTF8"

  _ <- LibPQ.exec conn "CREATE TABLE person (name TEXT NOT NULL, age INT NOT NULL, score DOUBLE PRECISION NOT NULL, motto TEXT)"
  _ <- LibPQ.exec conn "INSERT INTO person VALUES ('paul', 25, 30), ('albus', 103, 50.42)"

  pure conn
  where
    -- Look up a required environment variable, failing with its name.
    requireEnv :: String -> IO String
    requireEnv var =
      lookupEnv var
        >>= maybe
          (fail $ "setupConnection: environment variable " ++ var ++ " is not set")
          pure

    -- Quote a value for a libpq keyword/value connection string: wrap it in
    -- single quotes and backslash-escape embedded quotes and backslashes.
    quoteConnValue :: String -> String
    quoteConnValue value = "'" ++ concatMap escape value ++ "'"
      where
        escape c
          | c == '\'' || c == '\\' = ['\\', c]
          | otherwise = [c]

-- | Drop the @person@ fixture table and then close the connection.
--
-- Whatever 'LibPQ.exec' returns for the DROP statement is discarded before
-- 'LibPQ.finish' is called.
teardownConnection :: Connection -> IO ()
teardownConnection conn =
  LibPQ.exec conn "DROP TABLE person" >> LibPQ.finish conn

-- | Types a spec tree can be parameterised over, together with the hook
-- that supplies a value of that type to the spec items.
class SpecInput input where
  -- | Turn a spec that expects an @input@ into a plain 'Spec'.
  hook :: SpecWith input -> Spec

-- | Each spec item is wrapped with 'around', which acquires a connection via
-- 'setupConnection' and releases it via 'teardownConnection' using 'bracket'.
instance SpecInput Connection where
  hook spec = around (bracket setupConnection teardownConnection) spec

-- | Specs that need no input are passed through unchanged.
instance SpecInput () where
  hook spec = spec