from argparse import ArgumentParser
import os
import numpy as np
import re

from natsort import natsorted

# Model identifiers expected as the first component of a result file name.
# Order matters: main() joins these into a regex alternation and prints rows
# model by model in exactly this order.
models = """
    merged
    mobile
    dense_net_201
    inception_resnet_v2
    resnet50
    inception_v3
    vgg19
    vgg16
""".split()

# Dataset identifiers expected after the layer-source part of a result file
# name. Order matters: main() joins these into a regex alternation and prints
# rows dataset by dataset in exactly this order.
datasets = """
    cells_no_aug
    patterns_no_aug
    glomeruli_no_aug
    ulg_lbtd2_chimio_necrose
    ulg_lbtd_lba_new
    ulg_lbtd_tissus
    ulb_anapath_lba
    ulg_breast
""".split()


def main(argv):
    """Print one tab-separated row of metrics for each result ``.npz`` file.

    The directory given by ``--path`` is scanned for files whose names match
    ``<model>_<layer_source>_<dataset>_<task>.npz``, where ``<model>`` is one
    of ``models`` and ``<dataset>`` one of ``datasets``. For every matched
    file, the file name followed by the values listed in ``metrics`` is
    printed to stdout, tab-separated, each value formatted with four decimals.
    A metric missing from a file is printed as ``nan`` so the columns stay
    aligned. Rows are ordered by task, then dataset and model (in the order
    of the module-level lists), then naturally-sorted layer source.

    Names of files that do not match the pattern are reported on stderr.

    :param argv: command-line arguments without the program name. Only
        ``--path`` is read (defaults to the current directory); any other
        arguments are ignored.
    """
    import itertools
    import sys  # the module itself only imports sys when run as a script

    parser = ArgumentParser()
    parser.add_argument("--path", dest="path", default=".",
                        help="directory containing the result .npz files")
    params, _unknown = parser.parse_known_args(argv)
    print("Parameters: {}".format(params))

    metrics = ["cv_best_score", "accuracy", "precision", "recall", "f1-score", "roc_auc", "n_features"]
    # Groups: 1 = model, 2 = layer source, 3 = dataset, 4 = task.
    # Group 2 is greedy, so it normally keeps the "_" separating it from the
    # dataset; it is only used as a grouping/sorting key, never printed.
    pattern = re.compile(r"^({})_(.*)_?({})_(.*)\.npz$".format(
        "|".join(map(re.escape, models)),
        "|".join(map(re.escape, datasets)),
    ))

    files = os.listdir(params.path)
    matches = [pattern.match(file) for file in files]
    match_data = [(file, match) for file, match in zip(files, matches) if match is not None]
    unmatched = [file for file, match in zip(files, matches) if match is None]

    if unmatched:
        # Diagnostics go to stderr so stdout stays a clean table.
        print("{} file(s) were not matched !".format(len(unmatched)), file=sys.stderr)
        print("{}".format(unmatched), file=sys.stderr)

    layers_sources = natsorted(np.unique([match.group(2) for _, match in match_data]))
    tasks = np.unique([match.group(4) for _, match in match_data])
    # Key is (model, layer_source, dataset, task).
    match_dict = {match.groups(): file for file, match in match_data}

    # product() iterates the rightmost iterable fastest, matching the nesting
    # task -> dataset -> model -> layer_source.
    for task, dataset, model, layer_source in itertools.product(tasks, datasets, models, layers_sources):
        file = match_dict.get((model, layer_source, dataset, task))
        if file is None:
            continue
        full_path = os.path.join(params.path, file)
        if os.path.isdir(full_path):
            continue
        # NOTE(review): accuracy on "ulg_lbtd_tissus" is rescaled by 900/888;
        # presumably a correction for the number of evaluated samples - confirm.
        factor = (900 / 888.0) if dataset == "ulg_lbtd_tissus" else 1.0
        # NpzFile keeps the underlying file open until closed.
        with np.load(full_path) as data:
            scores = [
                float(data[metric] * (factor if metric == "accuracy" else 1.0))
                if metric in data else float("nan")
                for metric in metrics
            ]
        print("{}\t{}".format(file, "\t".join("{:1.4f}".format(v) for v in scores)))


if __name__ == "__main__":
    # Imported at module scope, so `sys` is also visible as a global to main().
    import sys

    cli_args = sys.argv[1:]
    main(cli_args)