#!/usr/bin/env python3
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import collections
import io
import json
import os
import sys

from tabulate import tabulate
import matplotlib.pyplot as plt

# Libraries included in every report. The order is significant: box() uses it
# for the plot labels, tab() emits one table row per entry in this order and
# treats the first entry as the baseline for its "Relative (latency)" column.
LIBRARIES = ("orjson", "ujson", "rapidjson", "simplejson", "json")


def aggregate():
    """Load the saved benchmark reports and index them by group and library.

    Looks inside the first entry that ``os.listdir`` returns for the
    ``.benchmarks`` directory (relative to the current working directory) and
    parses every file found there as JSON.

    Returns:
        A ``defaultdict`` mapping group name -> library name -> dict with:

        * ``"data"``: every sample from ``stats.data`` multiplied by 1000
        * ``"median"``: ``stats.median`` multiplied by 1000
        * ``"ops"``: ``stats.ops`` unchanged
        * ``"correct"``: the ``extra_info.correct`` flag unchanged

        The consumers in this file label the scaled values as milliseconds,
        so the reports are presumably recorded in seconds.

    If the same group/library pair occurs more than once, the entry read
    last replaces the earlier one.
    """
    root = ".benchmarks"
    run_dir = os.path.join(root, os.listdir(root)[0])

    results = collections.defaultdict(dict)
    for name in os.listdir(run_dir):
        with open(os.path.join(run_dir, name), "r") as handle:
            report = json.load(handle)

        for bench in report["benchmarks"]:
            stats = bench["stats"]
            info = bench["extra_info"]
            results[bench["group"]][info["lib"]] = {
                "data": [sample * 1000 for sample in stats["data"]],
                "median": stats["median"] * 1000,
                "ops": stats["ops"],
                "correct": info["correct"],
            }
    return results


def box(obj):
    """Save one horizontal box plot per benchmark group under ``doc/``.

    Args:
        obj: mapping of group name -> library name -> result dict carrying at
            least ``"data"`` (the samples to plot) and ``"correct"``.

    Groups are processed in sorted order. Each plot is written to
    ``doc/<group>.png`` where spaces in the group name become underscores
    and any ``.json`` substring is removed. Every name in ``LIBRARIES`` must
    be present in each group; a missing one raises ``KeyError``.
    """
    for group, results in sorted(obj.items()):
        # A library flagged as incorrect contributes the scalar -1 instead of
        # its samples; with the x axis clamped to start at 0 below, that is
        # presumably meant to keep it out of the visible range.
        samples = [
            results[lib]["data"] if results[lib]["correct"] else -1
            for lib in LIBRARIES
        ]

        figure = plt.figure(1, figsize=(9, 6))
        axes = figure.add_subplot(111)
        axes.boxplot(samples, vert=False, labels=LIBRARIES)
        axes.set_xlim(left=0)
        axes.set_xlabel("milliseconds")
        plt.title(group)

        stem = group.replace(" ", "_").replace(".json", "")
        plt.savefig("doc/{}.png".format(stem))
        # Close the figure so the next group starts from a clean canvas.
        plt.close()


def tab(obj):
    """Print a Markdown-style latency table for every benchmark group.

    Args:
        obj: mapping of group name -> library name -> result dict carrying
            ``"median"``, ``"ops"`` and ``"correct"``.

    Groups are emitted in reverse-sorted order, each under a ``####``
    heading, with one row per entry of ``LIBRARIES``. A library flagged as
    incorrect gets empty latency, throughput and relative cells. The
    relative column is each library's median divided by the median of the
    first library in ``LIBRARIES``; when that baseline is unavailable
    (flagged incorrect, not a float, or zero) the relative column is left
    empty for the whole group instead of raising.

    Every name in ``LIBRARIES`` must be present in each group; a missing one
    raises ``KeyError``.
    """
    buf = io.StringIO()
    headers = (
        "Library",
        "Median latency (milliseconds)",
        "Operations per second",
        "Relative (latency)",
    )
    for group, val in sorted(obj.items(), reverse=True):
        buf.write("\n" + "#### " + group + "\n\n")
        table = []
        for lib in LIBRARIES:
            entry = val[lib]
            correct = entry["correct"]
            table.append(
                [
                    lib,
                    entry["median"] if correct else None,
                    "%.1f" % entry["ops"] if correct else None,
                    None,  # relative latency, filled in below
                ]
            )

        baseline = table[0][1]
        # The baseline is None when the first library was flagged incorrect;
        # dividing by it (or by zero) would abort the whole report.
        has_baseline = isinstance(baseline, float) and baseline != 0
        for row in table:
            median = row[1]
            if isinstance(median, float):
                # Compute the ratio from the numeric median before the
                # median cell is replaced by its formatted string.
                row[3] = "%.2f" % (median / baseline) if has_baseline else None
                row[1] = "%.2f" % median
            else:
                row[3] = None
                row[1] = None
        buf.write(tabulate(table, headers, tablefmt="grid") + "\n")

    # Turn tabulate's "grid" output into a pipe table: strip the dashes of the
    # border lines, turn the "=" header rule into dashes, turn "+" corners
    # into pipes, drop the now-empty border lines (five bare pipes for four
    # columns) and collapse the blank lines left behind.
    # NOTE: the replacements apply to the whole text, so a "-", "=" or "+"
    # inside a group name or cell is rewritten as well.
    print(
        buf.getvalue()
        .replace("-", "")
        .replace("=", "-")
        .replace("+", "|")
        .replace("|||||", "")
        .replace("\n\n", "\n")
    )


# Entry point: the first command-line argument selects the report to produce.
# Only the two report generators are reachable (an arbitrary module-level name
# is not), and a missing argument prints the usage line instead of a
# traceback. The lookup alone is guarded so that a KeyError raised while
# aggregating or rendering surfaces as a real error rather than being
# misreported as bad usage.
try:
    command = {"box": box, "tab": tab}[sys.argv[1]]
except (IndexError, KeyError):
    sys.stderr.write("usage: graph (box|tab)\n")
    sys.exit(1)

command(aggregate())
