from pathlib import Path
import json
import numpy as np
import pandas as pd
import plotly.express as px
from typing import Union, Literal

# Height passed to fig.write_image() for the SVG export in make_bars_graph;
# the exported width is set to twice this value.
image_height = 500


def load_data(bench_suite: str):
    """Yield one row per raw measurement stored for ``bench_suite``.

    Every ``*{bench_suite}*.json`` file directly under ``benchmarks/results``
    is read.  The file stem must consist of exactly four ``-``-separated
    fields; the third is used as the JVM label and the fourth as the run
    number.  Each file holds a JSON list of entries with a ``"benchmark"``
    name, an optional ``"params"`` mapping and per-fork measurement lists
    under ``["primaryMetric"]["rawData"]``.

    Args:
        bench_suite: Suite name, used both in the file glob and to strip the
            ``benchmarks.{bench_suite}.`` prefix from benchmark names.

    Yields:
        Tuples ``(jvm, benchmark, run, fork, time)`` where ``fork`` is the
        index of the measurement list inside ``rawData`` and ``time`` is one
        raw value from that list.  When an entry has ``"params"``, their
        values are appended to the benchmark name as ``(v1,v2,...)``.

    Raises:
        ValueError: If a file stem does not split into exactly four fields,
            or its run field is not an integer.
    """
    results_dir = Path("benchmarks/results")
    name_prefix = f"benchmarks.{bench_suite}."
    # glob() order is filesystem-dependent; sort so rows come out in a
    # reproducible order.
    for bench_file in sorted(results_dir.glob(f"*{bench_suite}*.json")):
        _, _, jvm, run = bench_file.stem.split("-")
        run_number = int(run)
        for data in json.loads(bench_file.read_bytes()):
            benchmark = data["benchmark"].replace(name_prefix, "")
            if "params" in data:
                # str() so non-string parameter values (e.g. JSON numbers)
                # do not make join() raise TypeError.
                param_values = ",".join(map(str, data["params"].values()))
                benchmark += "(" + param_values + ")"
            for fork, times in enumerate(data["primaryMetric"]["rawData"]):
                for time in times:
                    yield (jvm, benchmark, run_number, fork, time)


# Chart kinds accepted by make_bars_graph; each value is the name of the
# plotly.express function used to draw the figure.
Mode = Literal["bar", "strip", "box"]
# Plot orientation accepted by make_bars_graph: "h" (horizontal) or "v" (vertical).
Orientation = Literal["h", "v"]


def _summarise_times(df):
    """Reduce raw measurements to one median/min/max row per (jvm, benchmark).

    Returns a frame with columns ``jvm``, ``benchmark``, ``time median``,
    ``time min``, ``time max`` plus the two error-bar lengths
    ``time error`` (median up to max) and ``time error minus`` (median down
    to min).
    """
    stats = (
        df.groupby(["jvm", "benchmark"])["time"]
        # String names instead of NumPy callables: same result, without the
        # deprecation warning recent pandas emits for np.median/np.min/np.max.
        .agg(["median", "min", "max"])
        .add_prefix("time ")
        .reset_index()
    )
    # plotly uses error_x/error_y for the positive direction and the
    # *_minus variant for the negative one, so the "plus" length must reach
    # up to the maximum and the "minus" length down to the minimum.
    stats["time error"] = stats["time max"] - stats["time median"]
    stats["time error minus"] = stats["time median"] - stats["time min"]
    return stats


def make_bars_graph(bench_suite, mode: Mode, orientation: Orientation):
    """Render the results of ``bench_suite`` as an HTML page and an SVG image.

    Args:
        bench_suite: Suite name passed to ``load_data`` and used as the
            figure title and in the output file names.
        mode: Name of the plotly.express function to draw with.  ``"bar"``
            plots the per-(jvm, benchmark) median with min/max error bars;
            the other modes plot every raw measurement (``"box"`` also shows
            all points).
        orientation: ``"h"`` puts the measured value on the x axis and the
            benchmark name on the y axis; any other value swaps them.

    Side effects:
        Writes ``benchmarks/graphs/{bench_suite}-{mode}-{orientation}.html``
        and the matching ``.svg``, creating the directory if necessary.
    """
    df = pd.DataFrame(
        load_data(bench_suite),
        columns=("jvm", "benchmark", "run", "fork", "time"),
    )
    is_bar = mode == "bar"
    if is_bar:
        df = _summarise_times(df)
    df = df.sort_values(by=["benchmark"], ascending=False)

    # The axis that carries the measured value depends on the orientation.
    value_axis, category_axis = ("x", "y") if orientation == "h" else ("y", "x")

    args = {
        value_axis: "time median" if is_bar else "time",
        category_axis: "benchmark",
        "color": "jvm",
        "orientation": orientation,
        "color_discrete_sequence": px.colors.qualitative.Pastel1,
        "template": "plotly_white",
    }
    if is_bar:
        args |= {
            "barmode": "group",
            f"error_{value_axis}": "time error",
            f"error_{value_axis}_minus": "time error minus",
        }
    elif mode == "box":
        args |= {"points": "all"}

    fig = getattr(px, mode)(df, **args)

    fig.update_layout(**{
        "title": bench_suite,
        f"{value_axis}axis_title": "Throughput [ops/s] (higher is better)",
        f"{category_axis}axis_title": "Benchmark",
    })

    # Make sure the target directory exists so the writes below do not fail
    # on a fresh checkout.
    graphs_dir = Path("benchmarks/graphs")
    graphs_dir.mkdir(parents=True, exist_ok=True)
    output_stem = f"benchmarks/graphs/{bench_suite}-{mode}-{orientation}"

    fig.write_html(f"{output_stem}.html", include_plotlyjs='cdn')
    fig.write_image(
        f"{output_stem}.svg",
        height=image_height,
        width=image_height * 2,
    )


# Executed whenever this module is run or imported: renders the horizontal
# grouped-bar chart for the "SumBenchmark" suite.
make_bars_graph("SumBenchmark", "bar", "h")