import sys
import os
from matplotlib import pyplot as plt
import numpy as np


# Marker symbols handed out in rotation, and the count of markers issued so far.
markers = 'xosv^'
marker_id = 0


def next_marker():
    """Return the next marker symbol, wrapping around after the last one.

    Each call advances the module-level ``marker_id`` counter by one.
    """
    global marker_id
    position = marker_id % len(markers)
    marker_id = marker_id + 1
    return markers[position]

# Line styles handed out in rotation, and the count of styles issued so far.
lines = ['-', '--', '-.']
line_id = 0


def next_line():
    """Return the next line style, wrapping around after the last one.

    Each call advances the module-level ``line_id`` counter by one.
    """
    global line_id
    position = line_id % len(lines)
    line_id = line_id + 1
    return lines[position]

class Experiment:
    """One set of scored predictions that can be drawn as a ROC curve.

    Attributes:
        name: Last path component of ``dirname``.
        label: Text used for the legend entry and for printed summaries.
        directory: ``dirname`` exactly as given.
        pairs: List of ``(score, gt)`` samples filled in by the caller.
            ``gt`` is counted as a positive and ``1 - gt`` as a negative,
            so it is expected to be 0 or 1.
    """

    def __init__(self, dirname, label):
        self.name = os.path.basename(os.path.normpath(dirname))
        self.label = label
        self.directory = dirname
        self.pairs = []

    @staticmethod
    def _rate(count, total):
        """Return ``count / total`` as a float, or NaN when ``total`` is zero."""
        # A rate is undefined when its class has no samples; NaN keeps the
        # curve computable instead of raising ZeroDivisionError.
        return float(count) / float(total) if total else float('nan')

    def _roc_points(self):
        """Compute the ROC curve for ``self.pairs`` without modifying it.

        Samples are visited by decreasing score, and samples sharing a score
        contribute a single point.

        Returns:
            ``(xs, ys)``: false-positive and true-positive rates, both
            starting with ``0.0``.
        """
        # Sort a copy: sorting and converting ``self.pairs`` in place would
        # turn it into an ndarray and break any later append or second call.
        ordered = sorted(self.pairs, key=lambda pair: pair[0], reverse=True)

        total_positive = sum(gt for _, gt in ordered)
        total_negative = len(ordered) - total_positive

        xs, ys = [0.0], [0.0]
        positive_in, negative_in = 0, 0
        for index, (score, gt) in enumerate(ordered):
            positive_in += gt
            negative_in += 1 - gt
            # Emit a point only once every sample with this score is counted.
            group_ends = (index + 1 == len(ordered)
                          or ordered[index + 1][0] < score)
            if group_ends:
                xs.append(self._rate(negative_in, total_negative))
                ys.append(self._rate(positive_in, total_positive))
        return xs, ys

    @staticmethod
    def _equal_error_point(xs, ys):
        """Return the ``(x, y)`` point where the curve crosses ``x + y == 1``.

        The point is linearly interpolated on the first segment whose end
        point has ``x + y > 1``. Returns ``None`` if no segment qualifies.
        """
        for x1, y1, x2, y2 in zip(xs[:-1], ys[:-1], xs[1:], ys[1:]):
            if x2 + y2 > 1:
                w2 = (1 - (x1 + y1)) / ((x2 + y2) - (x1 + y1))
                w1 = 1 - w2
                return x1 * w1 + x2 * w2, y1 * w1 + y2 * w2
        return None

    def plot_roc(self, ax):
        """Draw this experiment's ROC curve on ``ax`` and print a summary.

        When the samples carry exactly two distinct scores, the curve reduces
        to one operating point: it is drawn as a single marker and its TPR and
        FPR are printed. Otherwise the full curve is drawn as a line and the
        interpolated equal-error point is printed, if there is one.

        Args:
            ax: Any object with a matplotlib-style ``plot`` method.
        """
        xs, ys = self._roc_points()

        if len(xs) == 3:
            ax.plot(xs[1:2], ys[1:2], label=self.label, marker=next_marker(), markersize=10)
            print('{}: TPR {}, FPR {}'.format(self.label, ys[1], xs[1]))
            return

        ax.plot(xs, ys, label=self.label, linestyle=next_line())
        crossing = self._equal_error_point(xs, ys)
        if crossing is not None:
            print('{}: EER {} {}'.format(self.label, crossing[0], crossing[1]))

# Command line: ground_truth, title, output, then one or more
# (predictions, label) pairs. That is 4 fixed argv entries plus 2 per
# experiment, so a valid argv has an even length of at least 6. Anything else
# would crash on prediction_dirs[0] or silently drop an unlabeled directory.
if len(sys.argv) < 6 or len(sys.argv) % 2 != 0:
    print("usage: {} ground_truth title output predictions_1 label_1 [predictions_2 label_2 ...]".format(sys.argv[0]))
    sys.exit(1)

gt_dir = sys.argv[1]
title = sys.argv[2]
output = sys.argv[3]
prediction_dirs = sys.argv[4::2]
labels = sys.argv[5::2]

# Only files present in every prediction directory are evaluated, so all
# experiments are compared on the same samples.
filenames = set(os.listdir(prediction_dirs[0]))
for prediction_dir in prediction_dirs[1:]:
    filenames = filenames.intersection(os.listdir(prediction_dir))
filenames = sorted(filenames)

experiments = [Experiment(prediction_dir, label) for prediction_dir, label in zip(prediction_dirs, labels)]


# Ground truth: one integer per line, read from the same-named file in gt_dir.
gts = {}
for filename in filenames:
    gt_path = os.path.join(gt_dir, filename)
    with open(gt_path, 'r') as gt_file:
        gts[filename] = np.array([int(line) for line in gt_file])


for experiment in experiments:
    for filename in filenames:
        prediction_path = os.path.join(experiment.directory, filename)
        gt = gts[filename]
        with open(prediction_path, 'r') as pred_file:
            for i, line in enumerate(pred_file):
                score = float(line)
                # Samples labeled 1 are skipped (presumably "ignore" -- confirm
                # against the data format). The remaining labels are halved, so
                # 0 stays 0 (negative) and 2 becomes 1 (positive).
                if gt[i] != 1:
                    experiment.pairs.append((score, gt[i] // 2))
    # pyplot itself exposes plot(), so it serves as the drawing target.
    experiment.plot_roc(plt)


plt.legend()
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title(title)
plt.savefig(output)
plt.show()
