/* 
 * File:   ImageTestSet.cpp
 * Author: Marcel Duris
 * 
 * Created on March 30, 2012, 2:28 PM
 */

#include <cstdio>
#include "ImageTestSet.h"
#include "utils.h"

using namespace cv;
using namespace std;

/**
 * Creates an empty test set.
 *
 * All scalar members get a defined value here. In particular `count` is
 * set to 0, so hasNext()/next() are safe (and report "no images") even
 * before buildTestSet() or loadFromFile() has been called.
 */
ImageTestSet::ImageTestSet() {
    groundTruth.clear();
    masks.clear();
    resps.clear();
    pos = 0;
    count = 0;
    imagePath = "";
    maskDescriptionFile = "";
}

/**
 * Copy constructor.
 *
 * Copies the data source description, the mask definitions, the ground
 * truth and the recorded responses from `orig`. The iteration cursor of the
 * new object starts at the first image (pos = 0) rather than at orig's
 * position, as in the previous implementation.
 *
 * NOTE(review): only the members visible in this file are copied; if the
 * header declares further members they are default-constructed — confirm.
 */
ImageTestSet::ImageTestSet(const ImageTestSet& orig) {
    imagePath = orig.imagePath;
    maskDescriptionFile = orig.maskDescriptionFile;
    count = orig.count;
    masks = orig.masks;
    groundTruth = orig.groundTruth;
    resps = orig.resps;
    pos = 0;
}

/**
 * Destructor. No explicit cleanup is performed here.
 */
ImageTestSet::~ImageTestSet()
{
}

/**
 * Builds (or extends) the ground truth interactively.
 *
 * @param maskDescriptionFile file passed to loadMaskDesc() to obtain the masks.
 * @param imagePath           printf-style file name pattern that receives the
 *                            image index.
 * @param count               number of images in the set (indices 0..count-1).
 * @param append              true: keep existing ground truth and continue with
 *                            the first image that has none yet;
 *                            false: discard existing ground truth and start over
 *                            at image 0.
 *
 * Recorded responses are discarded and the cursor is rewound to image 0.
 */
void ImageTestSet::buildTestSet(string maskDescriptionFile, string imagePath, int count, bool append) {
    this->maskDescriptionFile = maskDescriptionFile;
    this->imagePath = imagePath;
    this->count = count;

    masks.clear();
    loadMaskDesc(maskDescriptionFile, masks);

    if (append) {
        // Resume right after the last image that already has ground truth.
        pos = groundTruth.size();
    } else {
        // Fresh session: previously, old truth was kept and new answers were
        // appended after it, starting at whatever position the cursor held.
        groundTruth.clear();
        pos = 0;
    }
    getTruthFromUser();

    resps.clear();
    pos = 0;
}

/**
 * Interactively collects per-mask ground truth for images pos .. count-1.
 *
 * Each image (file name = imagePath formatted with the image index) is shown
 * in the "Image" window with its index printed on it and every mask ROI
 * outlined. The mask currently being asked about gets a thin highlight, and
 * the user answers with a key:
 *   'y' - record 1 for the current mask;
 *   'n' - record 0 for the current mask;
 *   'r' - on the first mask: reuse the previous image's answers for all masks
 *         (they are drawn green = 1 / red = 0 when the image is shown); ignored
 *         if there are no previous answers yet. On a later mask: record 1.
 *   'q' - stop; the current image is not recorded and the window is closed.
 * One answer vector per completed image is appended to groundTruth.
 */
void ImageTestSet::getTruthFromUser() {

    vector< int > truth;
    namedWindow("Image", CV_WINDOW_NORMAL);
    char input[256], number[256];

    for (int i = pos; i < count; i++) {
        // imagePath is used as a printf-style format receiving the index;
        // snprintf keeps an over-long expansion from overflowing the buffer.
        snprintf(input, sizeof(input), imagePath.c_str(), i);
        Mat img = imread(input);
        snprintf(number, sizeof(number), "%d", i);
        putText(img, string(number), Point(10, 100), FONT_HERSHEY_SIMPLEX, 2, Scalar(255, 255, 0));

        // Show the previous image's answers (green = 1, red = 0) when a full
        // set exists, so the user can repeat them with 'r'; otherwise grey.
        for (size_t m = 0; m < masks.size(); m++) {
            if (truth.size() == masks.size()) {
                if (truth[m]) {
                    rectangle(img, masks[m].roi, Scalar(0, 255, 0), 2);
                } else {
                    rectangle(img, masks[m].roi, Scalar(0, 0, 255), 2);
                }
            } else {
                rectangle(img, masks[m].roi, Scalar(128, 128, 128), 2);
            }
        }

        vector< maskRegion >::iterator mask_it = masks.begin(), mask_end = masks.end();
        while (mask_it != mask_end) {
            // Highlight the mask being asked about. On the first mask the
            // highlight is skipped while previous answers exist, keeping the
            // green/red overview readable.
            if ((mask_it != masks.begin()) || (truth.size() == 0)) {
                rectangle(img, (*mask_it).roi, Scalar(255, 255, 0), 1);
            }

            imshow("Image", img);
            int key = 0;
            while (key != 'y' && key != 'n' && key != 'q' && key != 'r') {
                key = waitKey();
            }

            switch (key) {
                case 'y':
                    // A new answer for the first mask replaces the previous
                    // image's answers.
                    if (mask_it == masks.begin()) {
                        truth.clear();
                    }
                    truth.push_back(1);
                    break;
                case 'n':
                    if (mask_it == masks.begin()) {
                        truth.clear();
                    }
                    truth.push_back(0);
                    break;
                case 'q':
                    // Close the window before abandoning the session; the
                    // current image is not added to groundTruth.
                    destroyWindow("Image");
                    return;
                case 'r':
                    if (mask_it == masks.begin()) {
                        if (truth.size() != masks.size()) {
                            // Nothing complete to repeat (first image): ask
                            // again instead of recording an empty answer.
                            continue;
                        }
                        // Keep the previous answers unchanged for this image.
                        mask_it = mask_end;
                        continue;
                    }
                    truth.push_back(1);
                    break;
            }
            // Reset all outlines to grey before the next question.
            for (size_t m = 0; m < masks.size(); m++) {
                rectangle(img, masks[m].roi, Scalar(128, 128, 128), 2);
            }
            ++mask_it;
        }

        groundTruth.push_back(truth);

    }
    destroyWindow("Image");
}

/**
 * Writes the test set to outFile through cv::FileStorage.
 *
 * Stored keys: "imagePath", "count", "maskDescriptionFile", then
 * "groundTruth" (one map per image with an inline "imageTruth" sequence)
 * and "resps" (one map per image with an inline "imageResp" sequence).
 */
void ImageTestSet::saveToFile(string outFile) {
    FileStorage storage(outFile, FileStorage::WRITE);

    storage << "imagePath" << string(imagePath)
            << "count" << count
            << "maskDescriptionFile" << maskDescriptionFile;

    storage << "groundTruth" << "[";
    for (size_t imageIdx = 0; imageIdx < groundTruth.size(); ++imageIdx) {
        const vector< int >& answers = groundTruth[imageIdx];
        storage << "{:" << "imageTruth" << "[:";
        for (size_t m = 0; m < answers.size(); ++m) {
            storage << answers[m];
        }
        storage << "]" << "}";
    }
    storage << "]";

    storage << "resps" << "[";
    for (size_t imageIdx = 0; imageIdx < resps.size(); ++imageIdx) {
        const vector< float >& values = resps[imageIdx];
        storage << "{:" << "imageResp" << "[:";
        for (size_t m = 0; m < values.size(); ++m) {
            storage << values[m];
        }
        storage << "]" << "}";
    }
    storage << "]";

    storage.release();
}

/**
 * Replaces the contents of this test set with data read from inFile
 * (the layout written by saveToFile()).
 *
 * The masks are reloaded via loadMaskDesc() from the stored
 * maskDescriptionFile, and the cursor is reset to image 0. Warnings are
 * printed when the stored responses do not match the ground truth in size,
 * or when ground truth exists for fewer than `count` images (e.g. the
 * annotation session was stopped early). In the second case, iterating
 * with hasNext()/getCurrentGroundTruth() would run past the recorded data.
 */
void ImageTestSet::loadFromFile(string inFile) {
    pos = 0;

    FileStorage fs(inFile, FileStorage::READ);

    fs["imagePath"] >> imagePath;

    fs["count"] >> count;

    fs["maskDescriptionFile"] >> maskDescriptionFile;

    masks.clear();
    loadMaskDesc(maskDescriptionFile, masks);

    FileNode truthList = fs["groundTruth"];
    groundTruth.clear();
    for (FileNodeIterator it = truthList.begin(), it_end = truthList.end(); it != it_end; ++it) {
        vector< int > imageTruth;
        (*it)["imageTruth"] >> imageTruth;
        groundTruth.push_back(imageTruth);
    }

    FileNode respList = fs["resps"];
    resps.clear();
    for (FileNodeIterator it = respList.begin(), it_end = respList.end(); it != it_end; ++it) {
        vector< float > resp;
        (*it)["imageResp"] >> resp;
        resps.push_back(resp);
    }

    if ((resps.size() > 0) && (resps.size() != groundTruth.size())) {
        cout << "Warning, mismatched truth and response sizes.\n";
    }

    if (static_cast<int>(groundTruth.size()) < count) {
        cout << "Warning, ground truth is missing for some images.\n";
    }

    fs.release();
}

/** Reports whether the cursor still points at an image inside the set. */
bool ImageTestSet::hasNext() {
    return count > pos;
}

/**
 * Moves the cursor forward by one image.
 * @return true if the new position is still inside the set.
 */
bool ImageTestSet::next() {
    pos += 1;
    return count > pos;
}

/**
 * Loads the image at the current cursor position.
 *
 * The file name is imagePath used as a printf-style format with the current
 * index `pos`. The result of imread() is returned unchanged.
 */
Mat ImageTestSet::getCurrentImage() {
    char input[256];
    // snprintf bounds the write: an over-long expansion is truncated instead
    // of overflowing the buffer as sprintf would.
    snprintf(input, sizeof(input), imagePath.c_str(), pos);
    return imread(input);
}

/**
 * Returns the per-mask ground truth of the image at the cursor.
 *
 * @throws std::out_of_range if no ground truth was recorded for `pos`
 *         (e.g. annotation stopped before `count` images). This replaces an
 *         unchecked read past the end of groundTruth.
 */
vector< int > ImageTestSet::getCurrentGroundTruth() {
    return groundTruth.at(pos);
}

/**
 * Returns the recorded response vector of the image at the cursor.
 *
 * @throws std::out_of_range if fewer responses than `pos + 1` have been
 *         recorded. This replaces an unchecked read past the end of resps.
 */
vector< float > ImageTestSet::getCurrentResp() {
    return resps.at(pos);
}

/**
 * Appends one response vector to the end of resps.
 * Callers presumably push one vector per image, in index order.
 */
void ImageTestSet::pushCurrentResp(vector<float> resp) {
    resps.insert(resps.end(), resp);
}

/**
 * Convenience overload: number of correct decisions (TP + TN) for mask
 * `maskNo` at `threshold`. The individual counts are computed and discarded.
 */
int ImageTestSet::getGoodRespsForThreshold(int maskNo, float threshold) {
    int tp = 0;
    int tn = 0;
    int fp = 0;
    int fn = 0;
    return getGoodRespsForThreshold(maskNo, threshold, tp, tn, fp, fn);
}

/**
 * Returns true when image index `frame` counts as a daytime observation.
 *
 * The table holds pairs (a, b); indices with a < frame <= b are night
 * observations and every other index is daytime. Indices past the last pair
 * are treated as daytime (the previous pointer walk read beyond the end of
 * the table there).
 */
static bool isDaytimeFrame(int frame) {
    static const int nightBounds[] = {10, 91, 153, 235, 309, 379, 442, 521, 587, 666, 731, 810, 874, 954, 1018, 1096};
    const int boundCount = sizeof(nightBounds) / sizeof(nightBounds[0]);
    for (int k = 0; k + 1 < boundCount; k += 2) {
        if (frame > nightBounds[k] && frame <= nightBounds[k + 1]) {
            return false;
        }
    }
    return true;
}

/**
 * Scores the recorded responses of mask `maskNo` against the ground truth.
 *
 * An image is predicted positive when threshold <= resp[maskNo]. Day/night
 * correction: mask 0 is only evaluated on daytime images, and masks 4 and 5
 * only on night images. Outside those periods, the image counts as a true
 * negative regardless of the response.
 *
 * Side effect: walks the cursor from image 0 to `count`, leaving pos == count.
 *
 * @param truePositives  out: images with truth 1 predicted positive.
 * @param trueNegatives  out: images with truth 0 predicted negative, plus
 *                       images skipped by the day/night correction.
 * @param falsePositives out: images with truth 0 predicted positive.
 * @param falseNegatives out: images with truth 1 predicted negative.
 * @return truePositives + trueNegatives.
 */
int ImageTestSet::getGoodRespsForThreshold(int maskNo, float threshold,
        int &truePositives, int &trueNegatives,
        int &falsePositives, int &falseNegatives) {

    pos = 0;
    truePositives = 0;
    trueNegatives = 0;
    falseNegatives = 0;
    falsePositives = 0;

    while (hasNext()) {
        vector< int > truth = getCurrentGroundTruth();
        vector< float > resp = getCurrentResp();

        const bool isDay = isDaytimeFrame(pos);

        if ((maskNo == 0) && (!isDay)) {
            trueNegatives++;
            next();
            continue;
        }

        if (((maskNo == 4) || (maskNo == 5)) && (isDay)) {
            trueNegatives++;
            next();
            continue;
        }

        // The four cases are tested independently, so a truth value other than
        // 0/1, or a NaN response, is counted in none of them.
        if ((threshold <= resp[maskNo]) && (truth[maskNo] == 1)) {
            truePositives++;
        }

        if ((threshold > resp[maskNo]) && (truth[maskNo] == 0)) {
            trueNegatives++;
        }

        if ((threshold <= resp[maskNo]) && (truth[maskNo] == 0)) {
            falsePositives++;
        }

        if ((threshold > resp[maskNo]) && (truth[maskNo] == 1)) {
            falseNegatives++;
        }

        next();
    }

    return truePositives + trueNegatives;
}
