#include "linger_model.h"

#include <cmath>
#include <stdexcept>
#include <utility>
#include <vector>

using std::vector;

// Builds a linger model over periods [min_period, max_period].
//
// The period range, period decay and background/repeat stay probabilities are
// forwarded to Model. The linger probabilities and the per-nucleotide
// distribution are stored. Then one Level is allocated per period and the HMM
// is built via BuildHmm().
//
// nucleotides_distribution must have (at least) kNucleotideTypes entries; only
// the first kNucleotideTypes are used. They must sum to 1 within 1e-6.
//
// Throws std::invalid_argument if nucleotides_distribution has too few entries
// or does not sum to one.
LingerModel::LingerModel(int min_period,
                         int max_period,
                         double period_probability_decay,
                         double p_nucleotide_stay_background,
                         double p_nucleotide_stay_repeat,
                         double p_linger_nucleotide,
                         double p_linger_blank,
                         vector<double> nucleotides_distribution)
  : Model(min_period,
          max_period,
          period_probability_decay,
          p_nucleotide_stay_background,
          p_nucleotide_stay_repeat) {
  // Reject short vectors up front: the loop below reads kNucleotideTypes
  // entries, and reading past the end would be undefined behaviour.
  if (nucleotides_distribution.size() < kNucleotideTypes) {
    throw std::invalid_argument(
        "nucleotides distribution must have one entry per nucleotide type");
  }

  p_linger_nucleotide_ = Probability::FromP(p_linger_nucleotide);
  p_linger_blank_ = Probability::FromP(p_linger_blank);

  double nucleotides_distribution_sum = 0;
  for (unsigned i = 0; i < kNucleotideTypes; i++) {
    nucleotides_distribution_[i] = Probability::FromP(nucleotides_distribution[i]);
    nucleotides_distribution_sum += nucleotides_distribution[i];
  }
  // std::fabs, not unqualified abs(): the latter may pick the C int overload
  // and truncate the deviation to 0, silently accepting bad distributions.
  if (std::fabs(nucleotides_distribution_sum - 1) > 1e-6) {
    throw std::invalid_argument("nucleotides distribution must add up to one");
  }

  // One Level per period; levels_[period - min_period] holds period `period`.
  levels_.resize(max_period - min_period + 1);

  BuildHmm();
}

// Points every state slot of the sentinel level fake_level_ at
// post_repeat_state_. BuildLevelArcs uses fake_level_ in place of the missing
// previous/next level at the ends of the period range, so arcs aimed there
// land on post_repeat_state_.
void LingerModel::InitializeFakeLevel() {
  const auto sink = post_repeat_state_;

  fake_level_.entry_state_ = sink;
  fake_level_.blank_state_ = sink;
  fake_level_.blank_catchup_state_ = sink;

  for (unsigned nucleotide = 0; nucleotide < kNucleotideTypes; ++nucleotide) {
    fake_level_.nucleotide_states_[nucleotide] = sink;
    fake_level_.nucleotide_catchup_states_[nucleotide] = sink;
  }
}

unsigned LingerModel::BuildLevelStates(int period) {
  Level &level = levels_[period - min_period_];
  level.entry_state_ = hmm_->AddEpsilonState();
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    level.nucleotide_states_[label] = hmm_->AddSimpleEmissionState(label);
    level.nucleotide_catchup_states_[label] = hmm_->AddEpsilonState();
  }
  level.blank_state_ = hmm_->AddSimpleEmissionState(kBlankLabel);
  level.blank_catchup_state_ = hmm_->AddEpsilonState();
  return level.entry_state_;
}

void LingerModel::BuildLevelArcs(int period) {
  InitializeFakeLevel();
  Level &level = levels_[period - min_period_];
  Level &previous_level = (period == min_period_) ? fake_level_ : levels_[period - 1 - min_period_];
  Level &next_level = (period == max_period_) ? fake_level_ : levels_[period + 1 - min_period_];
  
  vector<unsigned> nucleotide_states(kNucleotideTypes);
  for (unsigned i = 0; i < kNucleotideTypes; i++) {
    nucleotide_states[i] = level.nucleotide_states_[i];
  }
  vector<vector<pair<unsigned, Probability>>> entering_contributions(kLabelTypes);
  vector<Probability> entering_masking_weights(kNucleotideTypes, Probability::FromP(0));
  
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    entering_contributions[label].push_back({label, Probability::FromP(1)});
    entering_masking_weights[label] = Probability::FromP(1);
  }
  
  hmm_->AddLeftDependentArcs(level.entry_state_,
                             nucleotide_states,
                             entering_contributions,
                             entering_masking_weights,
                             period);
  
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    vector<Probability> p_transitions;
    vector<unsigned> target_states;
    Probability p_rest = Probability::FromP(1);
    
    target_states.push_back(next_level.nucleotide_states_[label]);
    p_transitions.push_back(p_linger_nucleotide_ * p_rest);
    p_rest *= p_linger_nucleotide_.Negate();
    
    target_states.push_back(post_repeat_state_);
    p_transitions.push_back(p_nucleotide_stay_repeat_.Negate() * p_rest);
    p_rest *= p_nucleotide_stay_repeat_;
    
    target_states.push_back(level.nucleotide_catchup_states_[label]);
    p_transitions.push_back(p_rest);
    
    hmm_->AddUnconditionalArcs(level.nucleotide_states_[label],
                               target_states,
                               p_transitions);
  }
  
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    vector<vector<pair<unsigned, Probability>> label_contributions(kLabelTypes);
    for (unsigned left_label = 0; left_label < kLabelTypes; left_label++) {
      if (left_label != label) {
         label_contributions[left_label] = {{0, Probability::FromP(1)}};
      }
    }
    
    hmm_->AddLeftDependentArcs(level.nucleotide_catchup_states_[label],
                               {level.blank_state_, previous_level_.nucleotide_catchup_states_[label]},
                               label_contributions,
                               vector<Probability> (kLabelTypes, Probability::FromP(1)),
                               period);
  }
}

// Currently a no-op: the body is empty and `period` is unused.
// TODO(review): presumably meant to register the states built for level
// `period` (see BuildLevelStates / BuildLevelArcs); confirm and implement.
void LingerModel::RegisterLevel(int period) {
  
}
  
// Currently a no-op: no states are created here.
// TODO(review): presumably should create the background (non-repeat) states;
// confirm what the base Model expects from this override.
void LingerModel::BuildBackgroundStates() {
  
}

// Currently a no-op: no arcs are added here.
// TODO(review): presumably should add the arcs between background states;
// confirm what the base Model expects from this override.
void LingerModel::BuildBackgroundArcs() {
  
  
}

// Currently a no-op: the body is empty.
// TODO(review): presumably meant to register the background states with the
// base Model; confirm and implement.
void LingerModel::RegisterBackground() {
  
  
}
