#!/usr/bin/env python


import jinja2
import json
import sys


from dataclasses import dataclass
from pathlib import Path
from typing import Tuple


# Extent (x, y, z) of one base tile; the simulated domain is a grid of
# `domain_scale` such tiles (see Experiment.get_config).
SIZE = (192,) * 3


# Jinja2 environment for the batch-script templates, loaded from the
# "templates" directory that sits beside this script's own directory.
# NOTE(review): with its defaults, select_autoescape() should only escape
# HTML/XML templates, leaving the .j2 batch scripts unescaped — confirm.
templates_env = jinja2.Environment(
    autoescape=jinja2.select_autoescape(),
    loader=jinja2.FileSystemLoader(Path(__file__).parent.parent.joinpath("templates")),
)
        

@dataclass
class Experiment:
    """A single weak-scaling run: Slurm job settings plus its NAStJA config.

    ``num_blocks`` is the block grid the simulation domain is split into and
    ``domain_scale`` is how many ``SIZE`` tiles the domain spans along each
    axis. The ``*_path`` defaults contain Slurm-style placeholders (``%x``,
    ``%A``, ``%a``, ``${SLURM_...}``) that are handed to the batch template
    unchanged; presumably Slurm / the job shell expands them at run time.
    """

    job_name: str
    account: str
    partition: str
    nastja_binary_path: str
    nodes: int
    tasks: int
    num_blocks: Tuple[int, int, int]  # block grid (x, y, z)
    domain_scale: Tuple[int, int, int]  # number of SIZE tiles along (x, y, z)
    time: str = "00:15:00"
    extra_sbatch_line: str = ""  # e.g. "#SBATCH --gres=gpu:4" for GPU runs
    logfile_path: str = "/p/project/cellsinsilico/paulslustigebude/ma/experiments/eval/logs/%x-%A.%a"
    config_path: str = "/p/project/cellsinsilico/paulslustigebude/ma/experiments/eval/generated/config/${SLURM_JOB_NAME}.json"
    output_dir_path: str = "/p/scratch/cellsinsilico/paul/nastja-out/${SLURM_JOB_NAME}-${SLURM_ARRAY_JOB_ID}.${SLURM_ARRAY_TASK_ID}"

    def get_config(self) -> dict:
        """Build the NAStJA configuration dict for this experiment.

        Loads ``templates/weak.json`` (beside this script's directory) and
        overrides its ``Geometry`` section and ``Filling.cells`` list. The cell
        filling is a constant-zero cube spanning the whole domain, followed by
        one Voronoi-patterned sphere of radius 38 centred in every ``SIZE``
        tile.

        Raises:
            ValueError: If the domain extent is not an exact multiple of the
                block count along some axis.
        """
        template_path = Path(__file__).parent.parent / "templates" / "weak.json"
        with template_path.open(encoding="utf8") as f:
            config = json.load(f)

        size = tuple(base * scale for base, scale in zip(SIZE, self.domain_scale))
        blocksize = self._block_size(size, self.num_blocks)
        config["Geometry"] = {
            "blockcount": list(self.num_blocks),
            "blocksize": list(blocksize),
        }

        # Background: the whole domain set to value 0 before the spheres are placed.
        background = {
            "box": [
                [0, 0, 0],
                list(size)
            ],
            "celltype": 0,
            "component": 0,
            "pattern": "const",
            "seed": 0,
            "shape": "cube",
            "value": 0,
        }
        config["Filling"]["cells"] = [background, *self._tile_spheres()]

        return config

    @staticmethod
    def _block_size(size, num_blocks):
        """Return the per-axis block size, requiring ``size`` to split exactly.

        Floor division alone would silently drop the remainder, giving a block
        decomposition smaller than the domain the fillings cover.

        Raises:
            ValueError: If a block count is not positive or does not divide
                the corresponding extent.
        """
        for axis, (extent, count) in enumerate(zip(size, num_blocks)):
            if count <= 0 or extent % count:
                raise ValueError(
                    f"domain extent {extent} along axis {'xyz'[axis]} cannot be "
                    f"split evenly into {count} blocks"
                )
        return tuple(extent // count for extent, count in zip(size, num_blocks))

    def _tile_spheres(self):
        """Yield one Voronoi sphere filling centred in each ``SIZE`` tile.

        Tiles are visited with x varying fastest, then y, then z.
        """
        radius = 38
        for z in range(self.domain_scale[2]):
            for y in range(self.domain_scale[1]):
                for x in range(self.domain_scale[0]):
                    center = [
                        x * SIZE[0] + SIZE[0] // 2,
                        y * SIZE[1] + SIZE[1] // 2,
                        z * SIZE[2] + SIZE[2] // 2,
                    ]
                    yield {
                        "shape": "sphere",
                        "pattern": "voronoi",
                        "count": 715,
                        "radius": radius,
                        "center": center,
                        # Axis-aligned bounding box of the sphere.
                        "box": [
                            [c - radius for c in center],
                            [c + radius for c in center],
                        ],
                        "celltype": 9,
                        "seed": 758960,
                    }

    def write_batch_file(self, out_path: Path):
        """Render the Slurm batch script for this experiment to ``out_path``.

        NOTE(review): this weak-scaling generator renders ``strong-batch.j2``;
        confirm the strong-scaling template is intentionally shared.
        """
        t = templates_env.get_template("strong-batch.j2")
        t.stream(
            name=self.job_name,
            account=self.account,
            partition=self.partition,
            nodes=self.nodes,
            tasks=self.tasks,
            extra_sbatch_line=self.extra_sbatch_line,
            time=self.time,
            logfile_path=self.logfile_path,
            nastja_binary_path=self.nastja_binary_path,
            config_path=self.config_path,
            output_dir_path=self.output_dir_path,
        ).dump(str(out_path))


def make_cpu_ex(x: int, y: int, z: int) -> Experiment:
    """Build a weak-scaling experiment for the CPU-only build on the ``batch`` partition.

    One task runs per block and the blocks are packed 48 per node. Every
    4 x 4 x 3 group of blocks covers one ``SIZE`` tile, so the domain grows
    with the block grid.

    Args:
        x, y, z: Number of blocks along each axis.

    Raises:
        ValueError: If a block count is not positive or the grid is not a
            multiple of (4, 4, 3).
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``;
    # otherwise the floor divisions below would silently shrink the domain.
    if min(x, y, z) < 1:
        raise ValueError(f"block counts must be positive, got {(x, y, z)}")
    if x % 4 or y % 4 or z % 3:
        raise ValueError(f"block grid {(x, y, z)} must be a multiple of (4, 4, 3)")

    # Divisibility by (4, 4, 3) implies num_blocks is a multiple of 48,
    # so the node count below is exact.
    num_blocks = x * y * z
    num_nodes = num_blocks // 48

    return Experiment(
        job_name=f"weak-cpu-{x:02}-{y:02}-{z:02}",
        account="cellsinsilico",
        partition="batch",
        nastja_binary_path="/p/project/cellsinsilico/paulslustigebude/nastja/build-nocuda/nastja",
        nodes=num_nodes,
        tasks=num_blocks,
        num_blocks=(x, y, z),
        domain_scale=(x // 4, y // 4, z // 3),
    )


def make_gpu_ex(x: int, y: int, z: int) -> Experiment:
    """Build a weak-scaling experiment for the CUDA build on the ``gpus`` partition.

    One task runs per block, and enough GPUs are requested for one per task
    (at most 4 per node). Each block covers a full ``SIZE`` tile, so the
    domain grows with the block grid.

    Args:
        x, y, z: Number of blocks along each axis; must be positive.

    Raises:
        ValueError: If any block count is not positive.
    """
    if min(x, y, z) < 1:
        raise ValueError(f"block counts must be positive, got {(x, y, z)}")

    num_blocks = x * y * z
    gpus_per_node = min(num_blocks, 4)
    # Round the node count up: plain floor division under-allocates when a
    # block count above 4 is not a multiple of 4 (e.g. 6 blocks -> 1 node
    # with only 4 GPUs).
    num_nodes = (num_blocks + gpus_per_node - 1) // gpus_per_node

    return Experiment(
        job_name=f"weak-gpu-{x:02}-{y:02}-{z:02}",
        account="cellsinsilico",
        partition="gpus",
        extra_sbatch_line=f"#SBATCH --gres=gpu:{gpus_per_node}",
        nastja_binary_path="/p/project/cellsinsilico/paulslustigebude/nastja/build-cuda/nastja",
        nodes=num_nodes,
        tasks=num_blocks,
        num_blocks=(x, y, z),
        domain_scale=(x, y, z),
    )


def make_booster_ex(x: int, y: int, z: int) -> Experiment:
    """Build a weak-scaling experiment for the CUDA build on the ``booster`` partition.

    One task runs per block, and enough GPUs are requested for one per task
    (at most 4 per node). Each block covers a full ``SIZE`` tile, so the
    domain grows with the block grid.

    Args:
        x, y, z: Number of blocks along each axis; must be positive.

    Raises:
        ValueError: If any block count is not positive.
    """
    if min(x, y, z) < 1:
        raise ValueError(f"block counts must be positive, got {(x, y, z)}")

    num_blocks = x * y * z
    gpus_per_node = min(num_blocks, 4)
    # Round the node count up: plain floor division under-allocates when a
    # block count above 4 is not a multiple of 4 (e.g. 6 blocks -> 1 node
    # with only 4 GPUs).
    num_nodes = (num_blocks + gpus_per_node - 1) // gpus_per_node

    return Experiment(
        job_name=f"weak-booster-{x:02}-{y:02}-{z:02}",
        account="hkf6",
        partition="booster",
        extra_sbatch_line=f"#SBATCH --gres=gpu:{gpus_per_node}",
        nastja_binary_path="/p/project/cellsinsilico/paulslustigebude/nastja/build-cuda/nastja",
        nodes=num_nodes,
        tasks=num_blocks,
        num_blocks=(x, y, z),
        domain_scale=(x, y, z),
    )


# The full weak-scaling sweep. Each tuple is a block grid (x, y, z); the
# domain grows with the grid, so the work per block stays constant.
experiments = [
    # CPU runs: 48 blocks per node, one SIZE tile per 4 x 4 x 3 blocks.
    *(
        make_cpu_ex(bx, by, bz)
        for bx, by, bz in (
            (4, 4, 3),
            (4, 4, 6),
            (4, 4, 12),
            (4, 8, 12),
            (8, 8, 12),
            (8, 8, 24),
            (8, 16, 24),
            (16, 16, 24),
            (16, 16, 48),
        )
    ),
    # GPU-partition runs first, then Booster runs, over identical grids
    # (one SIZE tile per block).
    *(
        factory(bx, by, bz)
        for factory in (make_gpu_ex, make_booster_ex)
        for bx, by, bz in (
            (1, 1, 1),
            (1, 1, 2),
            (1, 2, 2),
            (2, 2, 2),
            (2, 2, 4),
            (2, 4, 4),
            (4, 4, 4),
            (4, 4, 8),
        )
    ),
]

if __name__ == "__main__":
    # Write every experiment's config (JSON) and batch script under
    # ../generated/{config,batch}/, relative to this script's directory.
    outdir = Path(__file__).parent.parent / "generated"
    config_dir = outdir / "config"
    batch_dir = outdir / "batch"
    # Create the output directories if missing; opening files inside a
    # non-existent directory would raise FileNotFoundError.
    config_dir.mkdir(parents=True, exist_ok=True)
    batch_dir.mkdir(parents=True, exist_ok=True)

    for e in experiments:
        print(f"Generating config for {e.job_name}", file=sys.stderr)
        # Build the config before opening the file, so a failure does not
        # leave a truncated/empty config behind.
        config = e.get_config()
        # Append the suffix instead of using with_suffix(), which would
        # replace any text after a dot in the job name.
        config_path = config_dir / f"{e.job_name}.json"
        with config_path.open("w", encoding="utf8") as f:
            json.dump(config, f, indent=2)
        print(f"Generating batch file for {e.job_name}", file=sys.stderr)
        e.write_batch_file(batch_dir / e.job_name)