from pathlib import Path
import json
import numpy as np
import pandas as pd
import plotly.express as px
from typing import Union, Literal

# Height of the exported SVG images, passed to ``fig.write_image``; the
# width is derived as 1.5x this value.
image_height: int = 800


def load_data(bench_suite):
    """Yield one row per raw JMH measurement of ``bench_suite``.

    Reads every ``benchmarks/results/*{bench_suite}*.json`` file (relative to
    the current working directory), in sorted file-name order. Each file stem
    must split on ``-`` into exactly four fields: the third is used as the JVM
    label and the fourth as the integer run number; the first two are ignored.
    Each file holds a JSON list of records with a ``benchmark`` name, optional
    ``params`` and ``primaryMetric.rawData`` (one list of samples per fork).

    Yields:
        ``(jvm, benchmark, run, fork, time)`` tuples. ``benchmark`` has the
        ``benchmarks.{bench_suite}.`` prefix removed and, when the record has
        params, their values appended as ``name(v1,v2,...)``. ``fork`` is the
        index of the sample list within ``rawData``.

    Raises:
        ValueError: if a file stem does not have four ``-``-separated fields
            or its run field is not an integer.
    """
    prefix = f"benchmarks.{bench_suite}."
    # Path.glob yields files in arbitrary, filesystem-dependent order; sorting
    # keeps the row order (and hence plot trace/category order) reproducible.
    for bench_file in sorted(Path("benchmarks/results").glob(f"*{bench_suite}*.json")):
        _, _, jvm, run = bench_file.stem.split("-")
        for data in json.loads(bench_file.read_bytes()):
            benchmark = data["benchmark"].removeprefix(prefix)
            if "params" in data:
                # str() so non-string parameter values don't break the join.
                benchmark += "(" + ",".join(map(str, data["params"].values())) + ")"
            for fork, times in enumerate(data["primaryMetric"]["rawData"]):
                for time in times:
                    yield (jvm, benchmark, int(run), fork, time)


# Accepted values for make_bars_graph's ``scale``, ``mode`` and
# ``orientation`` parameters. A multi-value Literal is the idiomatic spelling
# of a union of single-value Literals (PEP 586). Mode values double as the
# plotly.express function names that make_bars_graph dispatches to.
Scale = Literal["lin", "log"]
Mode = Literal["bar", "strip", "box"]
Orientation = Literal["h", "v"]


def make_bars_graph(bench_suite, max, mode: Mode, orientation: Orientation, scale: Scale):
    """Render one benchmark suite as a plotly chart and save it as HTML and SVG.

    Args:
        bench_suite: Suite name; selects the result files (via ``load_data``)
            and is used as the chart title and output file-name prefix.
        max: Upper bound of the throughput axis, only used when
            ``scale == "lin"`` (the axis then spans ``[0, max]``). The name
            shadows the builtin but is kept for caller compatibility.
        mode: ``"bar"`` draws one bar per (jvm, benchmark) at the median time,
            with error bars down to the min and up to the max; ``"box"`` and
            ``"strip"`` pass every raw sample to ``px.box`` / ``px.strip``.
        orientation: ``"h"`` puts throughput on the x axis and benchmarks on
            the y axis; ``"v"`` swaps them.
        scale: ``"log"`` for a logarithmic throughput axis, ``"lin"`` for a
            linear one.

    Writes ``benchmarks/graphs/{bench_suite}-{mode}-{orientation}-{scale}``
    with ``.html`` and ``.svg`` extensions, creating the directory if needed.
    """
    df = pd.DataFrame(
        load_data(bench_suite),
        columns=("jvm", "benchmark", "run", "fork", "time"),
    )
    if mode == "bar":
        # Collapse all runs and forks of a (jvm, benchmark) pair into one row.
        # String aggregation names avoid the pandas FutureWarning emitted when
        # numpy callables such as np.median are passed to groupby().agg().
        df = (
            df.drop(columns=["run", "fork"])
            .groupby(["jvm", "benchmark"])
            .agg([("median", "median"), ("min", "min"), ("max", "max")])
            .reset_index()
        )
        # Flatten MultiIndex columns: ("time", "median") -> "time median",
        # ("jvm", "") -> "jvm".
        df.columns = [" ".join(col).strip() for col in df.columns.values]
        # Asymmetric error bars: minus reaches the min, plus reaches the max.
        df["time error minus"] = df["time median"] - df["time min"]
        df["time error"] = df["time max"] - df["time median"]

    # Sort by benchmark, then by jvm in the opposite direction. One multi-key
    # sort is required: two consecutive single-key sort_values calls default to
    # quicksort, which is not stable, so the jvm order from the first sort was
    # not guaranteed to survive the second (making trace/legend order vary).
    reverse_order = orientation == "v"
    df = df.sort_values(
        by=["benchmark", "jvm"],
        ascending=[reverse_order, not reverse_order],
    )

    # Throughput goes on the value axis, benchmark names on the category axis.
    value_axis, category_axis = ("x", "y") if orientation == "h" else ("y", "x")
    args = {
        value_axis: "time median" if mode == "bar" else "time",
        category_axis: "benchmark",
        "color": "jvm",
        "orientation": orientation,
        "color_discrete_sequence": px.colors.qualitative.Pastel1,
        "template": "plotly_white",
    }
    if scale == "log":
        args[f"log_{value_axis}"] = True
    elif scale == "lin":
        args[f"range_{value_axis}"] = [0, max]
    if mode == "bar":
        args |= {
            "barmode": "group",
            f"error_{value_axis}": "time error",
            f"error_{value_axis}_minus": "time error minus",
        }

    # mode names the plotly.express function: px.bar, px.box or px.strip.
    fig = getattr(px, mode)(df, **args)

    log_note = "" if scale == "lin" else " Note the logarithmic scale."
    layout_args = {
        "title": bench_suite,
        f"{value_axis}axis_title": f"Throughput [ops/s] (Higher is better.{log_note})",
        f"{category_axis}axis_title": "Benchmark",
        "legend_traceorder": "reversed",
    }
    fig.update_layout(**layout_args)

    # Create the output directory so the writes below don't fail on a fresh checkout.
    Path("benchmarks/graphs").mkdir(parents=True, exist_ok=True)
    base_name = f"benchmarks/graphs/{bench_suite}-{mode}-{orientation}-{scale}"
    fig.write_html(f"{base_name}.html", include_plotlyjs='cdn')
    fig.write_image(
        f"{base_name}.svg",
        height=image_height,
        width=int(image_height * 1.5),
    )


# Render every plot mode for each suite. AppendBenchmark uses a logarithmic
# axis; SumBenchmark uses a linear axis capped at its upper bound.
for suite, upper_bound, axis_scale in (
    ("AppendBenchmark", 450000000, "log"),
    ("SumBenchmark", 450000, "lin"),
):
    for plot_mode in ("bar", "box", "strip"):
        make_bars_graph(suite, upper_bound, plot_mode, "h", axis_scale)