import copy
import json
import os
from dataclasses import dataclass

# Block counts (x, y, z) for the weak-scaling sweep. Each entry doubles the
# total number of blocks (and hence tasks) of the previous one, from 1 to 512,
# by doubling one axis at a time.
configurations = [
    # 1, 2, 4 tasks -- fits on a single node
    (1, 1, 1), (1, 1, 2), (1, 2, 2),
    # 8, 16, 32, 64 tasks
    (2, 2, 2), (2, 2, 4), (2, 4, 4), (4, 4, 4),
    # 128, 256, 512 tasks
    (4, 4, 8), (4, 8, 8), (8, 8, 8),
]

# NAStJA configuration template that every generated config is deep-copied
# from. JSON is UTF-8 by specification, so read it as UTF-8 explicitly
# instead of relying on the platform's locale encoding.
with open("templates/spheroid_weak.json", encoding="utf-8") as template_file:
    template = json.load(template_file)

# Extent of every block along x, y and z. The full domain along an axis is
# this value multiplied by the block count for that axis.
SIZE_X, SIZE_Y, SIZE_Z = 400, 400, 400

# GPUs on one booster node. The script requests one task per GPU, so this is
# also the largest number of tasks placed on a single node.
GPUS_PER_NODE = 4

# Parameters of the cell spheroid seeded into every block.
SPHERE_CELL_COUNT = 5500
SPHERE_RADIUS = 75
# Half the edge length of the cube, centred on the sphere, emitted as "box".
SPHERE_BOX_HALF_WIDTH = 90
# NOTE(review): the meaning of celltype 9 comes from the template/NAStJA setup; confirm.
SPHERE_CELLTYPE = 9

# Output locations, relative to the current working directory. CONFIG_DIR is
# also referenced from the batch script, relative to
# ${SOURCE_DIR}/ma/experiments on the cluster.
CONFIG_DIR = "configs/measurements/weak"
BATCH_DIR = "batch/measurements/weak"


def _cells_filling(blockcount, blocksize):
    """Build the NAStJA ``Filling.cells`` list for one block layout.

    Args:
        blockcount: Number of blocks along x, y and z.
        blocksize: Extent of one block along x, y and z.

    Returns:
        A list whose first entry fills the whole domain with ECM (value 0,
        celltype 0), followed by one Voronoi-seeded sphere per block, centred
        in that block, so every rank has cells to work on. Blocks are visited
        with x varying fastest, then y, then z.
    """
    domain = [count * size for count, size in zip(blockcount, blocksize)]
    cells = [
        {
            "shape": "cube",
            "box": [[0, 0, 0], domain],
            "value": 0,
            "celltype": 0,
        }
    ]

    nx, ny, nz = blockcount
    for z in range(nz):
        for y in range(ny):
            for x in range(nx):
                # Centre of this block, derived from the block size rather than
                # hard-coded, so the spheroid stays centred if the size changes.
                center = [
                    index * size + size // 2
                    for index, size in zip((x, y, z), blocksize)
                ]
                cells.append({
                    "shape": "sphere",
                    "pattern": "voronoi",
                    "count": SPHERE_CELL_COUNT,
                    "radius": SPHERE_RADIUS,
                    "center": center,
                    "box": [
                        [c - SPHERE_BOX_HALF_WIDTH for c in center],
                        [c + SPHERE_BOX_HALF_WIDTH for c in center],
                    ],
                    "celltype": SPHERE_CELLTYPE,
                })
    return cells


def _slurm_layout(ntasks, gpus_per_node=GPUS_PER_NODE):
    """Return ``(nodes, gpus_per_node)`` for running *ntasks* tasks, one per GPU.

    A run that fits on one node requests only as many GPUs as it has tasks.
    A larger run must fill whole nodes.

    Raises:
        ValueError: if *ntasks* does not fill a whole number of nodes.
    """
    if ntasks < gpus_per_node:
        return 1, ntasks
    nodes, leftover = divmod(ntasks, gpus_per_node)
    if leftover:
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        raise ValueError(
            f"{ntasks} tasks cannot fill whole nodes of {gpus_per_node} GPUs"
        )
    return nodes, gpus_per_node


def _batch_script(label, nodes, ntasks, gpus_per_node, config_path):
    """Render the SLURM batch script that runs NAStJA on *config_path*.

    *config_path* is resolved relative to ``${SOURCE_DIR}/ma/experiments`` on
    the cluster. The job array (``--array=1-5``) runs the same configuration
    five times, and each array task writes to its own output directory.
    """
    # NOTE(review): Slurm is generally not expected to create missing
    # directories for --output/--error; ensure logs/ exists in the directory
    # the job is submitted from.
    return f"""#!/usr/bin/env bash

#SBATCH --job-name={label}
#SBATCH --account=hkf6
#SBATCH --partition=booster
#SBATCH --nodes={nodes}
#SBATCH --ntasks={ntasks}
# Counted per node
#SBATCH --gres=gpu:{gpus_per_node}
#SBATCH --time=00:15:00
#SBATCH --output=logs/{label}-%A_%a.log
#SBATCH --error=logs/{label}-%A_%a.log
#SBATCH --array=1-5

SOURCE_DIR=/p/project/cellsinsilico/paulslustigebude
OUTPUT_DIR="/p/scratch/cellsinsilico/paul/nastja-out/{label}-${{SLURM_ARRAY_TASK_ID}}"

echo "outdir is ${{OUTPUT_DIR}}"

mkdir -p "${{OUTPUT_DIR}}"
source "${{SOURCE_DIR}}/activate-nastja-modules"

srun --unbuffered "${{SOURCE_DIR}}/nastja/build-cuda/nastja" \\
  -c "${{SOURCE_DIR}}/ma/experiments/{config_path}" \\
  -o "${{OUTPUT_DIR}}"
"""


# Create the output directories up front so the script also works on a fresh checkout.
os.makedirs(CONFIG_DIR, exist_ok=True)
os.makedirs(BATCH_DIR, exist_ok=True)

# Write one NAStJA config and one SLURM batch script per block layout.
for bx, by, bz in configurations:
    blockcount = [bx, by, bz]
    blocksize = [SIZE_X, SIZE_Y, SIZE_Z]

    nastja_config = copy.deepcopy(template)
    nastja_config["Geometry"]["blockcount"] = blockcount
    nastja_config["Geometry"]["blocksize"] = blocksize
    nastja_config["Filling"]["cells"] = _cells_filling(blockcount, blocksize)

    # One task per block.
    ntasks = bx * by * bz
    nodes, gpus_per_node = _slurm_layout(ntasks)

    # "400" in the label mirrors the block size used for this sweep.
    label = f"weak400-t{ntasks:04}n{nodes:03}g{gpus_per_node}x{bx}y{by}z{bz}"

    config_path = f"{CONFIG_DIR}/spheroid_{label}.json"
    with open(config_path, "w", encoding="utf8") as config_file:
        json.dump(nastja_config, config_file, indent=2)

    batch_config = _batch_script(label, nodes, ntasks, gpus_per_node, config_path)
    with open(f"{BATCH_DIR}/{label}", "w", encoding="utf8") as batch_config_file:
        batch_config_file.write(batch_config)