#include <alignment.h>
#include <matrix_band.h>
#include <limits>
#include <queue>
#include <fstream>
#include <cmath>
using namespace std;

/// Runs a windowed dynamic-programming alignment over a banded matrix and
/// returns the back-pointer matrix.
///
/// Cells are visited diagonal by diagonal (a diagonal is the set of cells with
/// the same row + column "index sum"). The diagonals are processed in
/// overlapping windows of 2 * min_lookahead + 1 index sums that advance by
/// min_lookahead; scores are recomputed from scratch inside every window and
/// never look at cells on diagonals before the window start.
///
/// @param rows            Unused by this function.
/// @param cols            Unused by this function.
/// @param scorer          Provides Score(row, col, flag, flag) for a visited cell.
/// @param movement_model  Provides MaxMovementRows(i), the longest step tried.
/// @param row_starts      First in-band column of every row.
/// @param row_ends        One past the last in-band column of every row.
/// @param min_lookahead   Window stride in index sums; values below 1 are
///                        raised to 1 (with a warning) so the loop advances.
/// @param local           If true every cell starts from score 0; otherwise
///                        every cell starts from the lowest double except
///                        cell (0, 0), which starts from 0.
/// @return For each cell the (row, column) it was reached from, or (-1, -1)
///         when no predecessor was recorded.
MatrixBand<pair<int, int>> alignment(int rows,
                                     int cols,
                                     const Scorer &scorer,
                                     const MovementModel &movement_model,
                                     const vector<int> &row_starts,
                                     const vector<int> &row_ends,
                                     int min_lookahead,
                                     bool local) {
  MatrixBand<pair<int, int>> come_from(row_starts, row_ends, {-1, -1}, {-1, -1});

  // An empty band has nothing to align; bail out before touching
  // row_starts[0] / row_ends.back().
  if (row_starts.empty() || row_ends.empty()) return come_from;

  const int row_count = static_cast<int>(row_starts.size());
  const int min_index_sum = row_starts[0];
  const int max_index_sum = row_ends.back() - 1 + static_cast<int>(row_ends.size()) - 1;

  // sum_appears[s] is the first row that owns a cell on diagonal s. It is
  // pre-filled with -1 so that diagonals owned by no row are recognised and
  // skipped by the check in the main loop.
  vector<int> sum_appears(max_index_sum + 1, -1);
  int next_sum = 0;
  for (int row = 0; row < row_count; row++) {
    for (int sum = max(row + row_starts[row], next_sum); sum < row + row_ends[row]; sum++) {
      sum_appears[sum] = row;
      next_sum = sum + 1;
    }
  }

  MatrixBand<double> score(row_starts, row_ends, 0, numeric_limits<double>::lowest());

  const double default_score = local ? 0 : numeric_limits<double>::lowest();

  // A non-positive stride would keep window_start from ever advancing.
  int lookahead = min_lookahead;
  if (lookahead < 1) {
    fprintf(stderr, "warning: min_lookahead must be positive, using 1\n");
    lookahead = 1;
  }

  fprintf(stderr, "steps back computed, running dtw\n");
  for (int window_start = min_index_sum;
       window_start + lookahead <= max_index_sum || window_start == min_index_sum;
       window_start += lookahead) {
    fprintf(stderr, "window start: %d / %d\r", window_start, max_index_sum);
    for (int index_sum = window_start; index_sum <= window_start + 2 * lookahead; index_sum++) {
      // Back-pointers are only written in the second half of a window; the
      // very first window writes them everywhere.
      const bool update_come_from = (index_sum >= window_start + lookahead) || (window_start == min_index_sum);

      if (index_sum >= static_cast<int>(sum_appears.size()) || sum_appears[index_sum] == -1) continue;
      for (int row = sum_appears[index_sum]; row < row_count && row + row_starts[row] <= index_sum; row++) {
        const int column = index_sum - row;

        double &current_cell = score[row][column];
        current_cell = default_score;
        if (row == 0 && column == 0 && !local) current_cell = 0;

        // Steps that move one column and one or more rows.
        if (column - 1 >= row_starts[row]) {
          double step_score = 0;

          const int from_col = column - 1;
          // `row - row_step >= 0` keeps the row_ends lookup in range even if
          // the movement model allows more rows than exist above this one.
          for (int row_step = 1;
               row_step <= movement_model.MaxMovementRows(row) &&
                   row - row_step >= 0 &&
                   row_ends[row - row_step] > column;
               row_step++) {
            const int from_row = row - row_step;
            // Never reach behind the current window.
            if (from_row + from_col < window_start) break;
            step_score += scorer.Score(from_row, from_col, row_step == 1, true);
            const double proposed_score = score[from_row][from_col] + step_score;
            if (proposed_score > current_cell) {
              current_cell = proposed_score;
              if (update_come_from) {
                come_from[row][column] = pair<int, int>(from_row, from_col);
              }
            }
          }
        }

        // Steps that move one row and one or more columns.
        // NOTE(review): the step limit here is MaxMovementRows(column) - confirm
        // that querying the row limit with a column index is intended.
        if (row - 1 >= 0 && row_ends[row - 1] > column) {
          double step_score = 0;
          const int from_row = row - 1;
          for (int col_step = 1;
               col_step <= movement_model.MaxMovementRows(column) && column - col_step >= row_starts[row];
               col_step++) {
            const int from_col = column - col_step;
            // Never reach behind the current window.
            if (from_row + from_col < window_start) break;
            step_score += scorer.Score(from_row, from_col, true, col_step == 1);
            const double proposed_score = score[from_row][from_col] + step_score;
            if (proposed_score > current_cell) {
              current_cell = proposed_score;
              if (update_come_from) {
                come_from[row][column] = pair<int, int>(from_row, from_col);
              }
            }
          }
        }
      }
    }
  }
  return come_from;
}

/// Extracts a set of local alignment paths from a banded matrix.
///
/// First runs alignment() in local mode to obtain back-pointers, optionally
/// dumps them to a log file, and then repeatedly picks the best-scoring cell
/// inside an interval of index sums (row + column), follows its back-pointers
/// to the start of the path, clears the diagonals the path covered and queues
/// the index-sum intervals on either side of it for the same treatment.
///
/// Log file layout: the number of rows, then one "start end" line per row,
/// then for every in-band cell two bytes (column step, row step) to its
/// predecessor; both bytes are -1 when the cell has no predecessor and steps
/// are capped at 127.
///
/// @param rows            Forwarded to alignment().
/// @param cols            Forwarded to alignment().
/// @param scorer          Provides Score(row, col, flag, flag).
/// @param movement_model  Forwarded to alignment().
/// @param row_starts      First in-band column of every row.
/// @param row_ends        One past the last in-band column of every row.
/// @param min_lookahead   Forwarded to alignment().
/// @param log_filename    Where to write the back-pointer dump; "" disables it.
/// @return The kept paths, each ordered from its first cell to its last.
///         Returns an empty result for an empty band.
vector<vector<pair<int, int>>> local_alignment(int rows,
                                               int cols,
                                               const Scorer &scorer,
                                               const MovementModel &movement_model,
                                               const vector<int> &row_starts,
                                               const vector<int> &row_ends,
                                               int min_lookahead,
                                               std::string log_filename) {
  // Nothing to align; also keeps row_starts[0] / row_ends.back() below valid.
  if (row_starts.empty() || row_ends.empty()) return {};

  MatrixBand<pair<int, int>> come_from = alignment(rows,
                                                   cols,
                                                   scorer,
                                                   movement_model,
                                                   row_starts,
                                                   row_ends,
                                                   min_lookahead,
                                                   /*local=*/true);
  MatrixBand<double> score(row_starts, row_ends, 0, numeric_limits<double>::lowest());
  fprintf(stderr, "\ndtw done, logging\n");

  const int row_count = static_cast<int>(row_starts.size());

  if (!log_filename.empty()) {
    ofstream out(log_filename.c_str(), ios_base::binary);
    if (!out) {
      // Logging is best effort: report the problem and carry on.
      fprintf(stderr, "warning: cannot open log file %s\n", log_filename.c_str());
    } else {
      out << row_starts.size() << "\n";
      for (int row = 0; row < row_count; row++) {
        out << row_starts[row] << " " << row_ends[row] << "\n";
      }
      for (int y = 0; y < row_count; y++) {
        for (int x = row_starts[y]; x < row_ends[y]; x++) {
          const pair<int, int> origin = come_from[y][x];
          if (origin == pair<int, int>(-1, -1)) {
            out.put(static_cast<int8_t>(-1));
            out.put(static_cast<int8_t>(-1));
          }
          else {
            const int x_delta = x - origin.second;
            const int y_delta = y - origin.first;
            if (x_delta > 127 || y_delta > 127) {
              fprintf(stderr, "warning: a step too long for encoding into one byte, truncating\n");
            }
            // Each step is stored in a single signed byte, capped at 127.
            const int8_t x_step = static_cast<int8_t>(min(x_delta, 127));
            const int8_t y_step = static_cast<int8_t>(min(y_delta, 127));
            out.put(x_step);
            out.put(y_step);
          }
        }
      }
    }
  }

  const int min_index_sum = row_starts[0];
  const int max_index_sum = row_ends.back() - 1 + static_cast<int>(row_ends.size()) - 1;

  // sum_appears[s] is the first row owning a cell on diagonal s, or -1 when no
  // row does; such diagonals are skipped below.
  vector<int> sum_appears(max_index_sum + 1, -1);
  int next_sum = 0;
  for (int row = 0; row < row_count; row++) {
    for (int sum = max(row + row_starts[row], next_sum); sum < row + row_ends[row]; sum++) {
      sum_appears[sum] = row;
      next_sum = sum + 1;
    }
  }

  fprintf(stderr, "getting greedy paths\n");
  vector<vector<pair<int, int>>> result;
  queue<pair<int, int>> unprocessed;  // closed intervals of index sums
  unprocessed.emplace(min_index_sum, max_index_sum);
  while (!unprocessed.empty()) {
    pair<int, int> interval = unprocessed.front();
    unprocessed.pop();
    fprintf(stderr, "interval [%d %d]              \r", interval.first, interval.second);

    // Recompute scores along the stored back-pointers and remember the best
    // cell of the interval.
    pair<double, pair<int, int>> best = {0, {-1, -1}};
    for (int index_sum = interval.first; index_sum <= interval.second; index_sum++) {
      if (index_sum % 1000 == 0) fprintf(stderr, "index_sum: %d/%d/%d                    \r", interval.first, index_sum, interval.second);
      if (sum_appears[index_sum] == -1) continue;
      for (int row = sum_appears[index_sum]; row < row_count && row_starts[row] + row <= index_sum; row++) {
        int column = index_sum - row;
        score[row][column] = 0;
        pair<int, int> from = come_from[row][column];
        if (from != pair<int, int>(-1, -1)) {
          int from_row = from.first;
          int from_col = from.second;
          // NOTE(review): alignment() passes `row_step == 1` / `col_step == 1`
          // as the Score flags (true for the cell next to the target), while
          // here they are true for the cell farthest from it; for steps longer
          // than one the two sums may differ - confirm which is intended.
          double step_score = 0;
          for (int r = from_row; r < row; r++) {
            for (int c = from_col; c < column; c++) {
              step_score += scorer.Score(r, c, r == from_row, c == from_col);
            }
          }
          score[row][column] = score[from] + step_score;
        }
        best = max(best, {score[row][column], {row, column}});
      }
    }
    // No positively scoring cell left in this interval.
    if (best.first <= 0) continue;

    // Walk the back-pointers from the best cell to the start of its path.
    pair<int, int> position = best.second;
    vector<pair<int, int>> path(1, position);
    int high_sum = position.first + position.second;
    while (come_from[position] != pair<int, int>(-1, -1)) {
      position = come_from[position];
      path.push_back(position);
    }
    int low_sum = position.first + position.second;

    // Wipe every diagonal the path covered so later passes cannot reuse it.
    for (int index_sum = low_sum + 1; index_sum <= high_sum; index_sum++) {
      if (sum_appears[index_sum] == -1) continue;
      for (int row = sum_appears[index_sum]; row < row_count && row_starts[row] + row <= index_sum; row++) {
        int column = index_sum - row;
        come_from[row][column] = {-1, -1};
        score[row][column] = 0;
      }
    }

    // Keep searching on both sides of the extracted path.
    if (interval.first < low_sum) {
      unprocessed.emplace(interval.first, low_sum);
    }
    if (high_sum < interval.second) {
      unprocessed.emplace(high_sum, interval.second);
    }

    reverse(path.begin(), path.end());
    // Keep the path only if its last row is not before its first column.
    if (path.back().first >= path[0].second) {
      result.push_back(path);
    }
  }
  return result;
}
