#include <cmath>
#include <cstddef>
#include <stdexcept>
#include "simple_model.h"
#include "hmm.h"
using std::vector;
using std::pair;

/**
 * Constructs a SimpleModel and builds its HMM.
 *
 * Validates the inputs, converts all probabilities to Probability values,
 * derives the blank/nucleotide split and the gap-state emissions, allocates one
 * Level per period in [min_period, max_period], and finally calls BuildHmm().
 *
 * @param min_period                   Smallest repeat period (forwarded to Model).
 * @param max_period                   Largest repeat period (forwarded to Model);
 *                                     must be >= min_period.
 * @param period_probability_decay     Forwarded to Model.
 * @param label_prior_probabilities    Prior over all labels; entry kBlankLabel is
 *                                     the blank prior. Must have exactly
 *                                     kLabelTypes entries summing to 1 (within 1e-6).
 * @param p_nucleotide_stay_background Forwarded to Model.
 * @param p_nucleotide_stay_repeat     Forwarded to Model.
 * @param p_nucleotide_is_indel        Stored as p_nucleotide_is_indel_.
 * @param match_matrix                 At least kNucleotideTypes x kNucleotideTypes
 *                                     probabilities; the leading square block
 *                                     becomes match_emission_matrix_.
 * @throws std::invalid_argument if the period range is empty, an input vector has
 *         the wrong size, or the label prior does not sum to one.
 */
SimpleModel::SimpleModel(int min_period, 
                         int max_period,
                         double period_probability_decay,
                         vector<double> label_prior_probabilities,
                         double p_nucleotide_stay_background,
                         double p_nucleotide_stay_repeat,
                         double p_nucleotide_is_indel,
                         vector<vector<double>> match_matrix)
  : Model(min_period,
          max_period, 
          period_probability_decay, 
          p_nucleotide_stay_background, 
          p_nucleotide_stay_repeat) {
  // Validate every input up front: the indexing below (label_prior_distribution_
  // [kBlankLabel], match_matrix[i][j], levels_.resize) is unchecked.
  if (max_period < min_period) {
    throw std::invalid_argument("max_period must not be smaller than min_period");
  }
  if (label_prior_probabilities.size() != static_cast<std::size_t>(kLabelTypes)) {
    throw std::invalid_argument("label distribution must have one entry per label type");
  }
  if (match_matrix.size() < static_cast<std::size_t>(kNucleotideTypes)) {
    throw std::invalid_argument("match matrix has too few rows");
  }
  for (unsigned i = 0; i < kNucleotideTypes; i++) {
    if (match_matrix[i].size() < static_cast<std::size_t>(kNucleotideTypes)) {
      throw std::invalid_argument("match matrix has too few columns");
    }
  }

  double sum = 0;
  for (double p : label_prior_probabilities) {
    sum += p;
  }
  // std::abs (from <cmath>) picks the double overload. An unqualified abs() may
  // bind to C's int abs(), truncating the difference so any |sum - 1| < 1 would
  // pass. The negated comparison also rejects a NaN sum.
  if (!(std::abs(sum - 1) <= 1e-6)) {
    throw std::invalid_argument("label distribution doesn't add up to one");
  }
  for (double p : label_prior_probabilities) {
    label_prior_distribution_.push_back(Probability::FromP(p));
  }

  p_nucleotide_is_indel_ = Probability::FromP(p_nucleotide_is_indel);
  p_blank_ = label_prior_distribution_[kBlankLabel];
  // Presumably 1 - p_blank_ (see Probability::Negate).
  p_nucleotide_ = p_blank_.Negate();

  match_emission_matrix_.resize(kNucleotideTypes, vector<Probability>(kNucleotideTypes));
  for (unsigned i = 0; i < kNucleotideTypes; i++) {
    for (unsigned j = 0; j < kNucleotideTypes; j++) {
      match_emission_matrix_[i][j] = Probability::FromP(match_matrix[i][j]);
    }
  }
  ComputeBlankEmissions();

  // One Level per period in [min_period, max_period].
  levels_.resize(max_period - min_period + 1);
  BuildHmm();
}

void SimpleModel::ComputeBlankEmissions() {
  blank_emissions_.resize(kLabelTypes);
  p_gap_or_insertion_ = Probability::FromP(0);
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    blank_emissions_[label] = label_prior_distribution_[label] * p_nucleotide_is_indel_;
    p_gap_or_insertion_ += blank_emissions_[label];
  }
  blank_emissions_[kBlankLabel] = label_prior_distribution_[kBlankLabel];
  p_gap_or_insertion_ += blank_emissions_[kBlankLabel];
  
  for (unsigned label = 0; label < kLabelTypes; label++) {
    blank_emissions_[label] /= p_gap_or_insertion_;
  }
}

void SimpleModel::BuildBackgroundStates() {
  vector<Probability> background_nucleotide_distribution(kLabelTypes, Probability::FromP(0));
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    background_nucleotide_distribution[label] = label_prior_distribution_[label] / p_nucleotide_;
  }
  vector<Probability> background_blank_distribution(kLabelTypes, Probability::FromP(0));
  background_blank_distribution[kBlankLabel] = Probability::FromP(1);
  background_nucleotide_state_ = hmm_->AddEmissionState(background_nucleotide_distribution, p_nucleotide_);
  background_blank_state_ = hmm_->AddEmissionState(background_blank_distribution, p_blank_);
  background_entry_state_ = hmm_->AddEpsilonState();
  post_repeat_state_ = hmm_->AddEpsilonState();
}

void SimpleModel::BuildBackgroundArcs() {
  Probability p_stay_background = p_blank_ + p_nucleotide_ * p_nucleotide_stay_background_;
  for (unsigned background_state : {background_nucleotide_state_, background_blank_state_}) {
    hmm_->AddUnconditionalArcs(background_state, 
                               {background_entry_state_, repeat_entry_state_},
                               {p_stay_background, p_stay_background.Negate()});
  }
  hmm_->AddUnconditionalArcs(background_entry_state_, 
                             {background_nucleotide_state_, background_blank_state_}, 
                             {p_nucleotide_, p_blank_});
  hmm_->AddUnconditionalArcs(post_repeat_state_, 
                             {background_entry_state_, repeat_entry_state_},
                             {p_stay_background, p_stay_background.Negate()});
}

void SimpleModel::RegisterBackground() {
  is_state_background_[background_nucleotide_state_] = true;
  is_state_background_[background_blank_state_] = true;
  is_state_epsilon_[background_entry_state_] = true;
}

unsigned SimpleModel::BuildLevelStates(int period) {
  Level &level = levels_[period - min_period_];
  level.match_state_ = hmm_->AddMatchState(match_emission_matrix_, period);
  level.gap_star_state_ = hmm_->AddEpsilonState();
  level.gap_state_ = hmm_->AddEmissionState(blank_emissions_);
  level.catchup_state_ = hmm_->AddEpsilonState();
  return level.match_state_;
}

void SimpleModel::BuildLevelArcs(int period) {
  Level &level = levels_[period - min_period_];
  unsigned previous_catchup_state_;
  if (period == min_period_) {
    previous_catchup_state_ = post_repeat_state_;
  }
  else {
    previous_catchup_state_ = levels_[period - min_period_ - 1].catchup_state_;
  }
  unsigned next_gap_star_state_;
  if (period == max_period_) {
    next_gap_star_state_ = post_repeat_state_;
  }
  else {
    next_gap_star_state_ = levels_[period - min_period_ + 1].gap_star_state_;
  }
  
  level.match_arcs_ = hmm_->AddUnconditionalArcs(level.match_state_, 
                                                 {post_repeat_state_, level.gap_star_state_}, 
                                                 {p_nucleotide_stay_repeat_.Negate(), p_nucleotide_stay_repeat_});
  level.gap_star_arcs_ = hmm_->AddUnconditionalArcs(level.gap_star_state_, 
                                                    {level.gap_state_, level.catchup_state_},
                                                    {p_gap_or_insertion_, p_gap_or_insertion_.Negate()});
  level.gap_arcs_ = hmm_->AddUnconditionalArcs(level.gap_state_, 
                                               {next_gap_star_state_}, 
                                               {Probability::FromP(1)});
  vector<vector<pair<unsigned, Probability>>> label_contributions(kLabelTypes);
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    label_contributions[label] = {{0, p_nucleotide_is_indel_.Negate()},
                                  {1, p_nucleotide_is_indel_}};
  }
  label_contributions[kBlankLabel] = {{1, Probability::FromP(1)}};
  level.catchup_arcs_ = hmm_->AddLeftDependentArcs(level.catchup_state_, 
                                                  {level.match_state_, previous_catchup_state_},
                                                  label_contributions,
                                                  vector<Probability> (kLabelTypes, Probability::FromP(1)),
                                                  period);
}

void SimpleModel::RegisterLevel(int period) {
  Level &level = levels_[period - min_period_];
  is_state_repeat_[level.match_state_] = true;
  is_state_repeat_[level.gap_star_state_] = true;
  is_state_repeat_[level.gap_state_] = true;
  is_state_repeat_[level.catchup_state_] = true;
  
  is_state_epsilon_[level.gap_star_state_] = true;
  is_state_epsilon_[level.catchup_state_] = true;
  
  for (unsigned arc_id : level.match_arcs_) {
    is_arc_advancing_left_[arc_id] = true;
    advancing_arc_period_[arc_id] = period+1;
  }
  is_arc_advancing_left_[level.catchup_arcs_[1]] = true;
  advancing_arc_period_[level.catchup_arcs_[1]] = period;
}
