#include <cmath>
#include <stdexcept>
#include "full_model.h"
using std::vector;
using std::pair;

// Constructs the model from plain double parameters.
//
// min_period, max_period, period_probability_decay,
// p_nucleotide_stay_background and p_nucleotide_stay_repeat are forwarded
// unchanged to the Model base class. The remaining scalar probabilities are
// converted with Probability::FromP and stored. One Level is allocated for
// every period in [min_period, max_period], and BuildHmm() is called last.
//
// nucleotides_distribution: at least kNucleotideTypes entries; the first
//   kNucleotideTypes of them must sum to 1 (tolerance 1e-6).
// match_matrix: at least kNucleotideTypes rows of at least kNucleotideTypes
//   entries each; only that leading square is read.
//
// Throws std::invalid_argument if either container is too small or if the
// distribution does not sum to one.
FullModel::FullModel(int min_period, 
                     int max_period, 
                     double period_probability_decay, 
                     double p_nucleotide_stay_background,
                     double p_nucleotide_stay_repeat,
                     double p_nucleotide_is_indel,
                     double p_linger_nucleotide,
                     double p_linger_blank,
                     vector<double> nucleotides_distribution,
                     vector<vector<double>> match_matrix) 
  : Model(min_period,
          max_period, 
          period_probability_decay, 
          p_nucleotide_stay_background, 
          p_nucleotide_stay_repeat) {
  
  // Reject undersized inputs up front: the loops below index the first
  // kNucleotideTypes entries unconditionally.
  if (nucleotides_distribution.size() < kNucleotideTypes) {
    throw std::invalid_argument("nucleotides distribution has too few entries");
  }
  if (match_matrix.size() < kNucleotideTypes) {
    throw std::invalid_argument("match matrix has too few rows");
  }
  for (unsigned i = 0; i < kNucleotideTypes; i++) {
    if (match_matrix[i].size() < kNucleotideTypes) {
      throw std::invalid_argument("match matrix has too few columns");
    }
  }
  
  match_matrix_.resize(kNucleotideTypes, vector<Probability>(kNucleotideTypes));
  for (unsigned i = 0; i < kNucleotideTypes; i++) {
    for (unsigned j = 0; j < kNucleotideTypes; j++) {
      match_matrix_[i][j] = Probability::FromP(match_matrix[i][j]);
    }
  }
  nucleotides_distribution_.resize(kNucleotideTypes);
  double nucleotides_distribution_sum = 0;
  for (unsigned i = 0; i < kNucleotideTypes; i++) {
    nucleotides_distribution_[i] = Probability::FromP(nucleotides_distribution[i]);
    nucleotides_distribution_sum += nucleotides_distribution[i];
  }
  // std::fabs, not unqualified abs: the latter can resolve to the C
  // int abs(int), which would truncate the difference to 0 and let any
  // sum in (0, 2) pass.
  if (std::fabs(nucleotides_distribution_sum - 1) > 1e-6) {
    throw std::invalid_argument("nucleotides distribution doesn't add up to one");
  }
  
  p_nucleotide_is_indel_ = Probability::FromP(p_nucleotide_is_indel);
  p_linger_nucleotide_ = Probability::FromP(p_linger_nucleotide);
  p_linger_blank_ = Probability::FromP(p_linger_blank);
  
  // Split of probability mass between "a nucleotide" and "a blank", taken as
  // the ratio of the two expected lengths.
  // NOTE(review): both expected lengths are computed from p_linger_blank and
  // p_linger_nucleotide is not used here; p_linger_blank == 0 also divides by
  // zero. Confirm that this formula is the intended one.
  double expected_nucleotide_length = 1 / p_linger_blank;
  double expected_blank_length = 1 / p_linger_blank - 1;
  p_nucleotide_ = Probability::FromP(expected_nucleotide_length / (expected_nucleotide_length + expected_blank_length));
  p_blank_ = Probability::FromP(expected_blank_length / (expected_nucleotide_length + expected_blank_length));
  
  // One level per admissible period; slot (period - min_period) holds that period.
  levels_.resize(max_period - min_period + 1);
  for (int period = min_period; period <= max_period; period++) {
    levels_[period - min_period].period_ = period;
  }
  
  BuildHmm();
}

void FullModel::BuildBackgroundStates() {
  background_entry_state_ = hmm_->AddEpsilonState();
  vector<Probability> p_prior_(kLabelTypes);
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    p_prior_[label] = nucleotides_distribution_[label] * p_nucleotide_;
  }
  p_prior_[kBlankLabel] = p_blank_;
  for (unsigned label = 0; label < kLabelTypes; label++) {
    background_generating_states_[label] = hmm_->AddSimpleEmissionState(label, p_prior_[label]);
  }
  post_repeat_state_ = hmm_->AddEpsilonState();
}

void FullModel::BuildBackgroundArcs() {
  vector<unsigned> entry_targets;
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    entry_targets.push_back(background_generating_states_[label]);
  }
  hmm_->AddUnconditionalArcs(background_entry_state_,
                             entry_targets, 
                             nucleotides_distribution_);
  hmm_->AddUnconditionalArcs(background_generating_states_[kBlankLabel], 
                             {background_generating_states_[kBlankLabel], background_entry_state_, repeat_entry_state_}, 
                             {p_linger_blank_, 
                              p_linger_blank_.Negate() * p_nucleotide_stay_background_, 
                              p_linger_blank_.Negate() * p_nucleotide_stay_background_.Negate()});
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    vector<unsigned> target_states;
    vector<Probability> p_transitions;
    Probability p_rest = Probability::FromP(1);
    
    target_states.push_back(background_generating_states_[label]);
    p_transitions.push_back(p_linger_nucleotide_ * p_rest);
    p_rest *= p_linger_nucleotide_.Negate();
    
    target_states.push_back(background_generating_states_[kBlankLabel]);
    p_transitions.push_back(p_linger_blank_ * p_rest);
    p_rest *= p_linger_blank_.Negate();
    
    target_states.push_back(background_entry_state_);
    p_transitions.push_back(p_nucleotide_stay_background_ * p_rest);
    p_rest *= p_nucleotide_stay_background_.Negate();
    
    target_states.push_back(repeat_entry_state_);
    p_transitions.push_back(p_rest);
    
    hmm_->AddUnconditionalArcs(background_generating_states_[label], 
                               target_states,
                               p_transitions);
  }
  hmm_->AddUnconditionalArcs(post_repeat_state_,
                             {background_entry_state_, repeat_entry_state_}, 
                             {p_nucleotide_stay_background_, p_nucleotide_stay_background_.Negate()});
}

// Flags the background states: every per-label emitting state and the entry
// state are marked as background; the entry state is additionally marked as
// epsilon.
// NOTE(review): post_repeat_state_ (created in BuildBackgroundStates) is not
// flagged here - confirm it is registered elsewhere or deliberately left out.
void FullModel::RegisterBackground() {
  for (unsigned label = 0; label < kLabelTypes; ++label) {
    is_state_background_[background_generating_states_[label]] = true;
  }

  is_state_background_[background_entry_state_] = true;
  is_state_epsilon_[background_entry_state_] = true;
}

// Points every state slot of fake_level_ at post_repeat_state_, so that a
// level used in place of a missing neighbour funnels all its arcs there.
void FullModel::InitializeFakeLevel() {
  const auto sink = post_repeat_state_;

  fake_level_.main_entry_state_ = sink;

  // Catch-up slots exist for every left label, blank included.
  for (unsigned left = 0; left < kLabelTypes; ++left) {
    for (unsigned right = 0; right < kLabelTypes; ++right) {
      fake_level_.catchup_states_[left][right] = sink;
    }
  }

  // Detailed-entry and generative slots exist for nucleotide left labels only.
  for (unsigned left = 0; left < kNucleotideTypes; ++left)  {
    for (unsigned right = 0; right < kLabelTypes; ++right) {
      fake_level_.detailed_entry_states_[left][right] = sink;
      fake_level_.generative_states_[left][right] = sink;
    }
  }
}

// Adds the left-dependent arcs leaving the level's main entry state. Arc i
// leads to detailed_entry_states_[i][kBlankLabel]; nucleotide label i
// contributes weight 1 to arc i only and has masking weight 1, every other
// label has no contribution and masking weight 0.
void FullModel::BuildMainEntryArcs(Level &level) {
  vector<unsigned> targets(kNucleotideTypes);
  vector<vector<pair<unsigned, Probability>>> contributions(kLabelTypes);
  vector<Probability> masking_weights(kLabelTypes, Probability::FromP(0));

  for (unsigned nucleotide = 0; nucleotide < kNucleotideTypes; ++nucleotide) {
    targets[nucleotide] = level.detailed_entry_states_[nucleotide][kBlankLabel];
    masking_weights[nucleotide] = Probability::FromP(1);
    contributions[nucleotide] = {{nucleotide, Probability::FromP(1)}};
  }

  hmm_->AddLeftDependentArcs(level.main_entry_state_,
                             targets,
                             contributions,
                             masking_weights,
                             level.period_);
}

// Adds the arcs from detailed_entry_states_[left_label][forbidden_right] to
// the generative states of every nucleotide except forbidden_right. Weights
// come from row left_label of the match matrix and are rescaled to sum to
// one, unless their total is not positive, in which case they are passed on
// as they are.
void FullModel::BuildDetailedEntryArcs(Level &level,
                                       unsigned left_label, 
                                       unsigned forbidden_right) {
  vector<unsigned> targets;
  vector<Probability> weights;
  for (unsigned candidate = 0; candidate < kNucleotideTypes; ++candidate) {
    if (candidate == forbidden_right) continue;
    targets.push_back(level.generative_states_[left_label][candidate]);
    weights.push_back(match_matrix_[left_label][candidate]);
  }

  // Total of the admitted weights, accumulated in insertion order.
  Probability total = Probability::FromP(0);
  for (Probability &weight : weights) {
    total += weight;
  }

  if (total > Probability::FromP(0)) {
    for (Probability &weight : weights) {
      weight /= total;
    }
  }

  hmm_->AddUnconditionalArcs(level.detailed_entry_states_[left_label][forbidden_right],
                             targets,
                             weights);
}

// Adds the arcs leaving generative_states_[left_label][right_label] for a
// nucleotide right_label. In order, the unassigned mass is split between:
//   1. the same (left, right) generative state of next_level;
//   2. the (left, blank) generative state of next_level;
//   3. post_repeat_state_ (recorded as advancing);
//   4. the generative states of next_level for every other nucleotide,
//      weighted by the nucleotide distribution renormalised without
//      right_label;
//   5. this level's catch-up state for (left, right) (recorded as advancing).
// Advancing arcs are stored on the level with period level.period_ + 1.
void FullModel::BuildNucleotideGenerativeArcs(Level &level, 
                                              const Level &next_level,
                                              unsigned left_label,
                                              unsigned right_label) {
  const unsigned source_state = level.generative_states_[left_label][right_label];
  vector<unsigned> targets;
  vector<Probability> weights;
  vector<unsigned> advancing_positions;
  Probability remaining = Probability::FromP(1);

  // Appends one arc; remembers its position when it is an advancing arc.
  auto add_arc = [&](unsigned target, Probability weight, bool advances) {
    if (advances) {
      advancing_positions.push_back(weights.size());
    }
    targets.push_back(target);
    weights.push_back(weight);
  };

  add_arc(next_level.generative_states_[left_label][right_label],
          p_linger_nucleotide_ * remaining, false);
  remaining *= p_linger_nucleotide_.Negate();

  add_arc(next_level.generative_states_[left_label][kBlankLabel],
          p_linger_blank_ * remaining, false);
  remaining *= p_linger_blank_.Negate();

  add_arc(post_repeat_state_, p_nucleotide_stay_repeat_.Negate() * remaining, true);
  remaining *= p_nucleotide_stay_repeat_;

  // Mass of the nucleotide distribution outside right_label, used to
  // renormalise the choice among the other nucleotides.
  Probability p_other_than_right = nucleotides_distribution_[right_label].Negate();
  Probability indel_mass = p_nucleotide_is_indel_ * remaining;
  for (unsigned inserted = 0; inserted < kNucleotideTypes; ++inserted) {
    if (inserted == right_label) continue;
    Probability p_this_one = nucleotides_distribution_[inserted] / p_other_than_right;
    add_arc(next_level.generative_states_[left_label][inserted],
            indel_mass * p_this_one, false);
  }
  remaining *= p_nucleotide_is_indel_.Negate();

  add_arc(level.catchup_states_[left_label][right_label], remaining, true);

  auto arc_ids = hmm_->AddUnconditionalArcs(source_state, targets, weights);
  for (unsigned position : advancing_positions) {
    level.advancing_arcs_.push_back(arc_ids[position]);
    level.advancement_periods_.push_back(level.period_ + 1);
  }
}

// Adds the arcs leaving generative_states_[left_label][kBlankLabel]. In
// order, the unassigned mass is split between:
//   1. the (left, blank) generative state of next_level;
//   2. post_repeat_state_ (recorded as advancing);
//   3. the generative states of next_level for every nucleotide, weighted by
//      the nucleotide distribution;
//   4. this level's catch-up state for (left, blank) (recorded as advancing).
// Advancing arcs are stored on the level with period level.period_ + 1.
void FullModel::BuildBlankGenerativeArcs(Level &level, 
                                         const Level &next_level,
                                         unsigned left_label) {
  const unsigned source_state = level.generative_states_[left_label][kBlankLabel];
  vector<unsigned> targets;
  vector<Probability> weights;
  vector<unsigned> advancing_positions;
  Probability remaining = Probability::FromP(1);

  // Appends one arc; remembers its position when it is an advancing arc.
  auto add_arc = [&](unsigned target, Probability weight, bool advances) {
    if (advances) {
      advancing_positions.push_back(weights.size());
    }
    targets.push_back(target);
    weights.push_back(weight);
  };

  add_arc(next_level.generative_states_[left_label][kBlankLabel],
          p_linger_blank_ * remaining, false);
  remaining *= p_linger_blank_.Negate();

  add_arc(post_repeat_state_, p_nucleotide_stay_repeat_.Negate() * remaining, true);
  remaining *= p_nucleotide_stay_repeat_;

  Probability indel_mass = p_nucleotide_is_indel_ * remaining;
  for (unsigned inserted = 0; inserted < kNucleotideTypes; ++inserted) {
    add_arc(next_level.generative_states_[left_label][inserted],
            indel_mass * nucleotides_distribution_[inserted], false);
  }
  remaining *= p_nucleotide_is_indel_.Negate();

  add_arc(level.catchup_states_[left_label][kBlankLabel], remaining, true);

  auto arc_ids = hmm_->AddUnconditionalArcs(source_state, targets, weights);
  for (unsigned position : advancing_positions) {
    level.advancing_arcs_.push_back(arc_ids[position]);
    level.advancement_periods_.push_back(level.period_ + 1);
  }
}

// Adds the left-dependent arcs leaving catchup_states_[left_label][forbidden_right].
// Arcs are appended in this order, each fed by exactly one label:
//   1. label left_label, weight 1  -> previous_level catch-up (left_label);
//   2. blank label, weight 1       -> previous_level catch-up (blank);
//   3. every other nucleotide d, weight p_nucleotide_is_indel_
//                                  -> previous_level catch-up (d);
//   4. every other nucleotide n, weight 1 - p_nucleotide_is_indel_
//                                  -> this level's detailed entry (n).
// Groups 1-3 are recorded as advancing arcs with period level.period_.
// All labels have masking weight 1.
void FullModel::BuildNucleotideCatchupArcs(Level &level, 
                                           const Level &previous_level, 
                                           unsigned left_label, 
                                           unsigned forbidden_right) {
  const unsigned source_state = level.catchup_states_[left_label][forbidden_right];
  vector<Probability> masking_weights(kLabelTypes, Probability::FromP(1));
  vector<vector<pair<unsigned, Probability>>> contributions(kLabelTypes);
  vector<unsigned> targets;
  vector<unsigned> advancing_positions;

  // Appends one arc fed by `fed_by` with the given weight; its position in
  // the arc list is the number of arcs added before it.
  auto add_arc = [&](unsigned target, unsigned fed_by, Probability weight, bool advances) {
    const unsigned position = targets.size();
    targets.push_back(target);
    contributions[fed_by].push_back({position, weight});
    if (advances) {
      advancing_positions.push_back(position);
    }
  };

  add_arc(previous_level.catchup_states_[left_label][forbidden_right],
          left_label, Probability::FromP(1), true);
  add_arc(previous_level.catchup_states_[kBlankLabel][forbidden_right],
          kBlankLabel, Probability::FromP(1), true);

  for (unsigned deleted = 0; deleted < kNucleotideTypes; ++deleted) {
    if (deleted == left_label) continue;
    add_arc(previous_level.catchup_states_[deleted][forbidden_right],
            deleted, p_nucleotide_is_indel_, true);
  }

  for (unsigned next = 0; next < kNucleotideTypes; ++next) {
    if (next == left_label) continue;
    add_arc(level.detailed_entry_states_[next][forbidden_right],
            next, p_nucleotide_is_indel_.Negate(), false);
  }

  auto arc_ids = hmm_->AddLeftDependentArcs(source_state,
                                            targets,
                                            contributions,
                                            masking_weights,
                                            level.period_);

  for (unsigned position : advancing_positions) {
    level.advancing_arcs_.push_back(arc_ids[position]);
    level.advancement_periods_.push_back(level.period_);
  }
}

// Adds the left-dependent arcs leaving catchup_states_[kBlankLabel][forbidden_right].
// Arcs are appended in this order:
//   1. blank label, weight 1 -> previous_level catch-up (blank);
//   2. for every nucleotide label, a pair of arcs fed by that label:
//        weight p_nucleotide_is_indel_     -> previous_level catch-up (label),
//        weight 1 - p_nucleotide_is_indel_ -> this level's detailed entry (label).
// Every arc into previous_level is recorded as an advancing arc with period
// level.period_. All labels have masking weight 1.
void FullModel::BuildBlankCatchupArcs(Level &level, 
                                      const Level &previous_level, 
                                      unsigned forbidden_right) {
  unsigned state = level.catchup_states_[kBlankLabel][forbidden_right];
  // One masking weight per label: sized with kLabelTypes (not a literal) so
  // it always matches label_contributions and the nucleotide catch-up arcs.
  vector<Probability> label_masking_weights(kLabelTypes, Probability::FromP(1));
  vector<vector<pair<unsigned, Probability>>> label_contributions(kLabelTypes);
  vector<unsigned> target_states;
  vector<unsigned> advancing_arc_indices;
  // Position the next appended arc will occupy in target_states.
  unsigned current_index = 0;
  
  target_states.push_back(previous_level.catchup_states_[kBlankLabel][forbidden_right]);
  label_contributions[kBlankLabel].push_back({current_index, Probability::FromP(1)});
  advancing_arc_indices.push_back(current_index);
  current_index++;
  
  for (unsigned label = 0; label < kNucleotideTypes; label++) {
    // Indel branch: move to the previous level's catch-up state.
    target_states.push_back(previous_level.catchup_states_[label][forbidden_right]);
    label_contributions[label].push_back({current_index, p_nucleotide_is_indel_});
    advancing_arc_indices.push_back(current_index);
    current_index++;
    
    // Complementary branch: enter this level through the detailed entry state.
    target_states.push_back(level.detailed_entry_states_[label][forbidden_right]);
    label_contributions[label].push_back({current_index, p_nucleotide_is_indel_.Negate()});
    current_index++;
  }
  
  auto arc_ids = hmm_->AddLeftDependentArcs(state, 
                                            target_states, 
                                            label_contributions, 
                                            label_masking_weights, 
                                            level.period_);
  
  for (unsigned index : advancing_arc_indices) {
    level.advancing_arcs_.push_back(arc_ids[index]);
    level.advancement_periods_.push_back(level.period_);
  }
}


// Creates all states of the level for `period` and returns its main entry
// state. Creation order: the main entry (epsilon); then, per nucleotide left
// label and per right label, a detailed-entry epsilon state followed by an
// emitting state for the right label; then one epsilon catch-up state per
// (left, right) label pair.
unsigned FullModel::BuildLevelStates(int period) {
  Level &level = levels_[period - min_period_];

  level.main_entry_state_ = hmm_->AddEpsilonState();

  for (unsigned left = 0; left < kNucleotideTypes; ++left) {
    for (unsigned right = 0; right < kLabelTypes; ++right) {
      level.detailed_entry_states_[left][right] = hmm_->AddEpsilonState();
      level.generative_states_[left][right] = hmm_->AddSimpleEmissionState(right);
    }
  }

  for (unsigned left = 0; left < kLabelTypes; ++left) {
    for (unsigned right = 0; right < kLabelTypes; ++right) {
      level.catchup_states_[left][right] = hmm_->AddEpsilonState();
    }
  }

  return level.main_entry_state_;
}

// Adds every arc of the level for `period`. The neighbouring levels are the
// ones for period - 1 and period + 1; at either end of the period range
// fake_level_ (re-initialised on each call) stands in for the missing one.
void FullModel::BuildLevelArcs(int period) {
  InitializeFakeLevel();

  const bool is_first_level = (period == min_period_);
  const bool is_last_level = (period == max_period_);

  Level &level = levels_[period - min_period_];
  Level &previous_level = is_first_level ? fake_level_ : levels_[period - 1 - min_period_];
  Level &next_level = is_last_level ? fake_level_ : levels_[period + 1 - min_period_];

  BuildMainEntryArcs(level);

  for (unsigned left = 0; left < kNucleotideTypes; ++left) {
    // Entry and catch-up arcs, one pair per forbidden right label.
    for (unsigned forbidden = 0; forbidden < kLabelTypes; ++forbidden) {
      BuildDetailedEntryArcs(level, left, forbidden);
      BuildNucleotideCatchupArcs(level, previous_level, left, forbidden);
    }

    // Arcs out of the emitting states: nucleotides first, then the blank.
    for (unsigned right = 0; right < kNucleotideTypes; ++right) {
      BuildNucleotideGenerativeArcs(level, next_level, left, right);
    }
    BuildBlankGenerativeArcs(level, next_level, left);
  }

  for (unsigned forbidden = 0; forbidden < kLabelTypes; ++forbidden) {
    BuildBlankCatchupArcs(level, previous_level, forbidden);
  }
}

// Flags the states and arcs of the level for `period`: every state of the
// level is marked as a repeat state, all of them except the emitting
// (generative) ones are also marked as epsilon, and each recorded advancing
// arc is marked as such together with its advancement period.
void FullModel::RegisterLevel(int period) {
  Level &level = levels_[period - min_period_];

  is_state_repeat_[level.main_entry_state_] = true;
  is_state_epsilon_[level.main_entry_state_] = true;

  for (unsigned left = 0; left < kLabelTypes; ++left) {
    // Detailed-entry and generative states exist for nucleotide left labels only.
    const bool left_is_nucleotide = left < kNucleotideTypes;
    for (unsigned right = 0; right < kLabelTypes; ++right) {
      if (left_is_nucleotide) {
        is_state_repeat_[level.detailed_entry_states_[left][right]] = true;
        is_state_epsilon_[level.detailed_entry_states_[left][right]] = true;
        is_state_repeat_[level.generative_states_[left][right]] = true;
      }
      is_state_repeat_[level.catchup_states_[left][right]] = true;
      is_state_epsilon_[level.catchup_states_[left][right]] = true;
    }
  }

  // advancing_arcs_ and advancement_periods_ are parallel lists.
  for (unsigned k = 0; k < level.advancing_arcs_.size(); ++k) {
    is_arc_advancing_left_[level.advancing_arcs_[k]] = true;
    advancing_arc_period_[level.advancing_arcs_[k]] = level.advancement_periods_[k];
  }
}
