#include <hmm.h>
#include <queue>
#include <cstdio>
#include <cmath>
#include <cassert>
using std::vector;
using std::pair;
using std::queue;

// Creates an empty HMM. The epsilon-state ordering is not available until
// Finalize() has run PrecomputeEpsilons().
Hmm::Hmm()
    : epsilons_precomputed_(false) {
}

// Base state: remembers the starting and ending probabilities it was given.
Hmm::State::State(Probability starting_p, Probability ending_p)
    : starting_p_(starting_p), ending_p_(ending_p) {
}

// Out-of-line destructor; there is nothing to release explicitly here.
Hmm::State::~State() {}

// Returns the starting probability supplied at construction.
Probability Hmm::State::GetStartingP() {
  return this->starting_p_;
}

// Returns the ending probability supplied at construction.
Probability Hmm::State::GetEndingP() {
  return this->ending_p_;
}

// An emitting state whose output is a mixture over labels; `distribution`
// is indexed by label (see EmissionProbability).
Hmm::EmissionState::EmissionState(Probability starting_p,
                                  Probability ending_p,
                                  vector<Probability> distribution)
    : State(starting_p, ending_p), emission_distribution_(distribution) {
}

// Mixture emission: the sum over every label of the CTC conditional
// probability of that label at `position`, weighted by this state's
// emission weight for the label.
Probability Hmm::EmissionState::EmissionProbability(const CtcPredictions &ctc_predictions,
                                                    int position) {
  const auto &row = ctc_predictions.conditionals_[position];
  Probability total = Probability::FromP(0);
  for (unsigned l = 0; l < kLabelTypes; ++l) {
    total += row[l] * emission_distribution_[l];
  }
  return total;
}

// A state that always emits the single label `emitted_label`.
Hmm::SimpleEmissionState::SimpleEmissionState(Probability starting_p,
                                              Probability ending_p,
                                              unsigned emitted_label)
    : State(starting_p, ending_p), emitted_label_(emitted_label) {
}

// Single-label emission: the CTC conditional probability of `emitted_label_`
// at `position`.
Probability Hmm::SimpleEmissionState::EmissionProbability(const CtcPredictions &ctc_predictions,
                                                          int position) {
  const auto label = emitted_label_;
  return ctc_predictions.conditionals_[position][label];
}


// A non-emitting state; only the start/end probabilities are stored.
Hmm::EpsilonState::EpsilonState(Probability starting_p, Probability ending_p)
    : State(starting_p, ending_p) {
}

// Epsilon states emit nothing: the probability is always zero, whatever the
// input or position.
Probability Hmm::EpsilonState::EmissionProbability(const CtcPredictions &ctc_predictions,
                                                   int position) {
  static_cast<void>(ctc_predictions);  // intentionally unused
  static_cast<void>(position);         // intentionally unused
  return Probability::FromP(0);
}

// A state whose emission depends on the prediction `period` positions earlier.
// `emission_matrix` is indexed [left nucleotide][right nucleotide] (see
// EmissionProbability).
Hmm::MatchingNucleotideState::MatchingNucleotideState(Probability starting_p,
                                                      Probability ending_p,
                                                      vector<vector<Probability>> emission_matrix,
                                                      int period)
    : State(starting_p, ending_p),
      emission_matrix_(emission_matrix),
      period_(period) {
}

// Emission probability that ties the label at `position` to the nucleotide
// predicted `period_` positions earlier ("left" position):
//
//   result = sum_l  predictions_[left][l] / N
//                   * sum_r emission_matrix_[l][r] * conditionals_[position][r]
//
// where l, r range over nucleotide labels and N = predictions_[left][blank].Negate()
// (presumably 1 - P(blank) — TODO confirm against Probability::Negate), i.e. the
// left-label distribution is conditioned on the left label not being blank.
//
// Returns 0 when there is no left position (position < period_). It also returns 0
// when N is not positive (the left position is certainly blank). In that case every
// nucleotide prediction at the left is zero as well, so the unguarded formula would
// evaluate 0/0.
Probability Hmm::MatchingNucleotideState::EmissionProbability(const CtcPredictions &ctc_predictions,
                                                              int position) {
  int left_position = position - period_;
  if (left_position < 0) return Probability::FromP(0);

  Probability normalization_factor = ctc_predictions.predictions_[left_position][kBlankLabel].Negate();
  // Guard the division below; also rejects a NaN normalizer.
  if (!(normalization_factor > Probability::FromP(0))) {
    return Probability::FromP(0);
  }

  Probability result = Probability::FromP(0);
  for (unsigned left_label = 0; left_label < kNucleotideTypes; left_label++) {
    // Probability of emitting the observed right label given left nucleotide `left_label`.
    Probability sum = Probability::FromP(0);
    for (unsigned right_label = 0; right_label < kNucleotideTypes; right_label++) {
      sum += emission_matrix_[left_label][right_label] * ctc_predictions.conditionals_[position][right_label];
    }
    result += ctc_predictions.predictions_[left_position][left_label] / normalization_factor * sum;
  }
  return result;
}

// Out-of-line destructor for the transition-computer base class.
Hmm::TransitionComputer::~TransitionComputer() = default;

// Stores the arc ids this computer overrides, the per-label lists of
// (local arc index, share) contributions, a masking weight per label, and the
// look-back distance `period` (see Compute).
Hmm::LeftDependentTransitions::LeftDependentTransitions(vector<unsigned> arc_ids,
                                                        vector<vector<pair<unsigned, Probability>>> label_contributions,
                                                        vector<Probability> label_masking_weights,
                                                        int period)
    : arc_ids_(arc_ids),
      label_contributions_(label_contributions),
      label_masking_weights_(label_masking_weights),
      period_(period) {
}

// Computes transition probabilities for this computer's arcs that depend on
// the label predicted `period_` positions before `position`.
//
// Each label's weight is predictions_[left][label] * label_masking_weights_[label].
// That weight is distributed over the arcs according to label_contributions_[label],
// a list of (local index into arc_ids_, share) pairs. The per-arc totals are then
// divided by the sum of all label weights, when that sum is positive.
//
// Returns (arc id, probability) pairs. The result is empty when the look-back
// position is before the start. In that case TransitionProbabilities keeps these
// arcs at their static probability (AddArc initializes it to 0).
vector<pair<unsigned, Probability>> Hmm::LeftDependentTransitions::Compute(const CtcPredictions &ctc_predictions,
                                                                           int position) {
  int left_position = position - period_;
  if (left_position < 0) {
    return {};
  }

  vector<pair<unsigned, Probability>> result(arc_ids_.size());
  for (unsigned i = 0; i < arc_ids_.size(); i++) {
    result[i].first = arc_ids_[i];
    result[i].second = Probability::FromP(0);
  }

  Probability labels_weight_sum = Probability::FromP(0);
  for (unsigned label = 0; label < kLabelTypes; label++) {
    Probability label_weight = ctc_predictions.predictions_[left_position][label] * label_masking_weights_[label];
    labels_weight_sum += label_weight;
    for (const pair<unsigned, Probability> &effect : label_contributions_[label]) {
      result[effect.first].second += label_weight * effect.second;
    }
  }
  if (labels_weight_sum > Probability::FromP(0)) {
    // Bind by reference so the normalization is written back into `result`.
    for (pair<unsigned, Probability> &p : result) {
      p.second /= labels_weight_sum;
    }
  }
  return result;
}

// Stores three parallel arrays indexed by arc position: the arc ids, each
// arc's default probability, and each arc's look-back period (see Compute).
Hmm::RepeatEnteringTransitions::RepeatEnteringTransitions(vector<unsigned> arc_ids,
                                                          vector<Probability> default_probabilities,
                                                          vector<int> periods)
    : arc_ids_(arc_ids),
      default_probabilities_(default_probabilities),
      periods_(periods) {
}

// Computes transition probabilities for arcs entering repeat states.
//
// For arc i, the weight is default_probabilities_[i] multiplied by
// predictions_[position - periods_[i]][blank].Negate(), presumably the probability
// that the label periods_[i] steps back is not blank (TODO confirm against
// Probability::Negate). Arcs whose look-back position is before the start get 0.
// If the total weight is positive, the weights are normalized by it.
//
// Returns one (arc id, probability) pair per arc, in arc_ids_ order.
vector<pair<unsigned, Probability>> Hmm::RepeatEnteringTransitions::Compute(const CtcPredictions &ctc_predictions,
                                                                            int position) {
  vector<pair<unsigned, Probability>> result(arc_ids_.size());
  Probability sum = Probability::FromP(0);
  for (unsigned i = 0; i < arc_ids_.size(); i++) {
    result[i].first = arc_ids_[i];
    int left_position = position - periods_[i];
    if (left_position < 0) {
      result[i].second = Probability::FromP(0);
      continue;
    }
    Probability p_nucleotide = ctc_predictions.predictions_[left_position][kBlankLabel].Negate();
    // Assign directly instead of accumulating into the default-constructed value,
    // whose meaning depends on Probability's default constructor.
    result[i].second = default_probabilities_[i] * p_nucleotide;
    sum += result[i].second;
  }
  if (sum > Probability::FromP(0)) {
    for (pair<unsigned, Probability> &p : result) {
      p.second /= sum;
    }
  }
  return result;
}


// Appends `state` and gives it empty outgoing and incoming arc lists.
// Returns the new state's id, which is its index in states_. The HMM takes
// ownership: ~Hmm deletes every stored state.
unsigned Hmm::AddState(State *state) {
  states_.push_back(state);
  arcs_from_.emplace_back();
  arcs_to_.emplace_back();
  return static_cast<unsigned>(states_.size() - 1);
}

// Adds a mixture-emission state and registers it as a normal (emitting) state.
// Returns the new state id.
unsigned Hmm::AddEmissionState(vector<Probability> emission_distribution,
                               Probability starting_p,
                               Probability ending_p) {
  const unsigned new_id = AddState(new EmissionState(starting_p, ending_p, emission_distribution));
  normal_states_ids_.push_back(new_id);
  return new_id;
}

// Adds a state that emits exactly `emitted_label` and registers it as a
// normal (emitting) state. Returns the new state id.
unsigned Hmm::AddSimpleEmissionState(unsigned emitted_label,
                                     Probability starting_p,
                                     Probability ending_p) {
  const unsigned new_id = AddState(new SimpleEmissionState(starting_p, ending_p, emitted_label));
  normal_states_ids_.push_back(new_id);
  return new_id;
}

// Adds a non-emitting state. It is tracked separately from normal states so
// that PrecomputeEpsilons can order it. Returns the new state id.
unsigned Hmm::AddEpsilonState(Probability starting_p,
                              Probability ending_p) {
  const unsigned new_id = AddState(new EpsilonState(starting_p, ending_p));
  epsilon_states_ids_.push_back(new_id);
  return new_id;
}

// Adds a MatchingNucleotideState with the given substitution matrix and
// look-back `period`, registered as a normal (emitting) state.
// Returns the new state id.
unsigned Hmm::AddMatchState(vector<vector<Probability>> emission_matrix,
                            int period,
                            Probability starting_p,
                            Probability ending_p) {
  const unsigned new_id = AddState(new MatchingNucleotideState(starting_p, ending_p, emission_matrix, period));
  normal_states_ids_.push_back(new_id);
  return new_id;
}
  
// Creates an arc from state `from` to state `to` with static probability 0.
// The arc is recorded in both adjacency lists as (other endpoint, arc id).
// Returns the new arc id.
unsigned Hmm::AddArc(unsigned from, unsigned to) {
  const unsigned new_arc = static_cast<unsigned>(static_transition_probabilities_.size());
  static_transition_probabilities_.push_back(Probability::FromP(0));
  arcs_from_[from].emplace_back(to, new_arc);
  arcs_to_[to].emplace_back(from, new_arc);
  return new_arc;
}

// Adds one arc from `from` to each state in `to`, in order.
// Returns the arc ids, parallel to `to`.
vector<unsigned> Hmm::AddArcs(unsigned from, const vector<unsigned> &to) {
  vector<unsigned> arc_ids;
  arc_ids.reserve(to.size());
  for (unsigned target : to) {
    arc_ids.push_back(AddArc(from, target));
  }
  return arc_ids;
}
vector<unsigned> Hmm::AddUnconditionalArcs(unsigned from_state, 
                                           vector<unsigned> to_states, 
                                           vector<Probability> p_transitions) {
  vector<unsigned> result = AddArcs(from_state, to_states);
  Probability p_sum = Probability::FromP(0);
  for (Probability p : p_transitions) {
    p_sum += p;
  }
  
  for (unsigned i = 0; i < result.size(); i++) {
    static_transition_probabilities_[result[i]] = p_transitions[i];
    if (p_sum > Probability::FromP(0)) {
      static_transition_probabilities_[result[i]] /= p_sum;
    }
  }
  return result;
}

vector<unsigned> Hmm::AddLeftDependentArcs(unsigned from_state,
                                           vector<unsigned> to_states,
                                           vector<vector<pair<unsigned, Probability>>> label_contributions,
                                           vector<Probability> label_masking_weights,
                                           int period) {
  for (unsigned label = 0; label < kLabelTypes; label++) {
    Probability sum = Probability::FromP(0);
    for (const auto &cont_pair : label_contributions[label]) {
      sum += cont_pair.second;
    }
    if (sum > Probability::FromP(0)) {
      for (auto &cont_pair : label_contributions[label]) {
        cont_pair.second /= sum;
      }
    }
  }
  vector<unsigned> result = AddArcs(from_state, to_states);
  transition_computers_.push_back(new LeftDependentTransitions(result, label_contributions, label_masking_weights, period));
  return result;
}

vector<unsigned> Hmm::AddRepeatEnteringArcs(unsigned from_state,
                                            vector<unsigned> to_states,
                                            vector<Probability> default_probabilities,
                                            vector<int> periods) {
  vector<unsigned> result = AddArcs(from_state, to_states);
  transition_computers_.push_back(new RepeatEnteringTransitions(result, default_probabilities, periods));
  return result;
}

vector<Probability> Hmm::TransitionProbabilities(const CtcPredictions &ctc_predictions, int position) const {
  vector<Probability> result = static_transition_probabilities_;
  for (TransitionComputer *computer : transition_computers_) {
    auto pairs = computer->Compute(ctc_predictions, position);
    for (const pair<unsigned, Probability> &pair : pairs) {
      result[pair.first] = pair.second;
    }
  }
  return result;
}


// Deletes the states and transition computers that the Add* methods allocated.
Hmm::~Hmm() {
  for (unsigned i = 0; i < states_.size(); i++) {
    delete states_[i];
  }
  for (auto *computer : transition_computers_) {
    delete computer;
  }
}

// Orders the epsilon (non-emitting) states topologically, considering only arcs
// that leave epsilon states. This lets TransitionSum/TransitionMax push mass
// through chains of epsilon states in a single pass.
//
// Afterwards:
//   * epsilon_states_ids_ holds the epsilon states in topological order (each
//     appears after every epsilon state that has an arc into it);
//   * epsilon_states_reverse_order_ holds the same list reversed, for BACKWARD;
//   * epsilons_precomputed_ is set.
//
// If the epsilon states contain a cycle, a warning is printed. The states that
// could not be ordered are then left out of epsilon_states_ids_.
// NOTE(review): because those states are dropped, a second call no longer sees
// them and will not repeat the warning.
void Hmm::PrecomputeEpsilons() {
  int states_count = static_cast<int>(states_.size());
  int epsilon_states_count = static_cast<int>(epsilon_states_ids_.size());

  // indegree[s] = number of arcs into s that start at an epsilon state.
  vector<bool> is_epsilon(states_count, false);
  vector<int> indegree(states_count, 0);
  for (unsigned state_id : epsilon_states_ids_) {
    is_epsilon[state_id] = true;
    for (const auto &arc : arcs_from_[state_id]) {
      indegree[arc.first]++;
    }
  }

  // Seed with epsilon states that have no epsilon predecessor.
  queue<unsigned> q;
  for (unsigned state_id : epsilon_states_ids_) {
    if (indegree[state_id] == 0) {
      q.push(state_id);
    }
  }

  // Rebuild epsilon_states_ids_ in topological order (Kahn's algorithm).
  epsilon_states_ids_.clear();
  int processed = 0;
  while (!q.empty()) {
    unsigned cur = q.front();
    q.pop();
    processed++;
    epsilon_states_ids_.push_back(cur);

    for (const auto &arc : arcs_from_[cur]) {
      unsigned other = arc.first;
      indegree[other]--;
      if (indegree[other] == 0 && is_epsilon[other]) {
        q.push(other);
      }
    }
  }
  if (processed < epsilon_states_count) {
    fprintf(stderr, "Warning: an epsilon-loop in the HMM (results will be incorrect)\n");
  }
  // Build the reversed copy with reverse iterators. This avoids the unqualified
  // reverse(), which was only found via ADL and needed <algorithm>, never included.
  epsilon_states_reverse_order_.assign(epsilon_states_ids_.rbegin(), epsilon_states_ids_.rend());
  epsilons_precomputed_ = true;
}

void Hmm::Finalize() {
  PrecomputeEpsilons();
}

vector<Probability> Hmm::GetStartingProbabilities() const {
  vector<Probability> result;
  for (State *state : states_) {
    result.push_back(state->GetStartingP());
  }
  return result;
}

vector<Probability> Hmm::GetEndingProbabilities() const {
  vector<Probability> result;
  for (State *state : states_) {
    result.push_back(state->GetEndingP());
  }
  return result;
}


vector<Probability> Hmm::Emission(const vector<Probability> &state_probabilities,
                                  const CtcPredictions &ctc_predictions,
                                  int position) const {
  vector<Probability> result(states_.size(), Probability::FromP(0));
  for (unsigned state_id : normal_states_ids_) {
    result[state_id] = state_probabilities[state_id] * states_[state_id]->EmissionProbability(ctc_predictions, position);
  }
  return result;
}

// Performs one sum-product transition step.
// FORWARD follows outgoing arcs (arcs_from_); otherwise incoming arcs
// (arcs_to_) are followed. Mass from normal states is pushed along their arcs
// first. Epsilon states then forward whatever they received, in topological
// order (reversed for the non-FORWARD direction), and are left with zero mass.
// Returns the new per-state mass and the mass carried by each arc.
Hmm::TransitionSumResult Hmm::TransitionSum(const vector<Probability> &state_probabilities,
                                            const CtcPredictions &ctc_predictions,
                                            int position,
                                            Direction direction) const {
  assert(epsilons_precomputed_);
  const bool forward = (direction == FORWARD);

  TransitionSumResult out;
  auto &mass = out.new_state_probabilities_;
  auto &through_arc = out.arc_probabilities_;
  mass.resize(states_.size(), Probability::FromP(0));
  through_arc.resize(static_transition_probabilities_.size(), Probability::FromP(0));

  const vector<Probability> arc_p = TransitionProbabilities(ctc_predictions, position);
  const vector<vector<pair<unsigned, unsigned>>> &adjacency = forward ? arcs_from_ : arcs_to_;

  // Phase 1: push the incoming mass of every normal state along its arcs.
  for (unsigned src : normal_states_ids_) {
    const vector<pair<unsigned, unsigned>> &arcs = adjacency[src];
    for (unsigned k = 0; k < arcs.size(); k++) {
      const unsigned dst = arcs[k].first;
      const unsigned arc = arcs[k].second;
      Probability p = state_probabilities[src] * arc_p[arc];
      through_arc[arc] = p;
      mass[dst] += p;
    }
  }

  // Phase 2: epsilon states relay the mass they just received, then drop it.
  const vector<unsigned> &epsilon_order = forward ? epsilon_states_ids_ : epsilon_states_reverse_order_;
  for (unsigned eps : epsilon_order) {
    for (const pair<unsigned, unsigned> &arc_pair : adjacency[eps]) {
      Probability p = mass[eps] * arc_p[arc_pair.second];
      through_arc[arc_pair.second] = p;
      mass[arc_pair.first] += p;
    }
    mass[eps] = Probability::FromP(0);
  }
  return out;
}

// Performs one max-product (Viterbi) transition step. It uses the same
// traversal as TransitionSum, but each state keeps only its best incoming
// candidate, together with the source state (come_from_) and arc
// (come_through_arc_) that produced it. Epsilon states relay their best value
// and are then reset to zero. Their back-pointers are kept.
Hmm::TransitionMaxResult Hmm::TransitionMax(const vector<Probability> &state_probabilities,
                                            const CtcPredictions &ctc_predictions,
                                            int position,
                                            Direction direction) const {
  assert(epsilons_precomputed_);
  const bool forward = (direction == FORWARD);

  TransitionMaxResult out;
  out.new_state_probabilities_.resize(states_.size(), Probability::FromP(0));
  out.come_from_.resize(states_.size(), kInvalidId);
  out.come_through_arc_.resize(states_.size(), kInvalidId);

  const vector<Probability> arc_p = TransitionProbabilities(ctc_predictions, position);
  const vector<vector<pair<unsigned, unsigned>>> &adjacency = forward ? arcs_from_ : arcs_to_;

  // Records `candidate` for `dst` if it is at least as good as the current value.
  // The `>=` means ties (including zero-probability candidates) go to the later arc.
  auto relax = [&out](unsigned src, unsigned dst, unsigned arc, Probability candidate) {
    if (candidate >= out.new_state_probabilities_[dst]) {
      out.new_state_probabilities_[dst] = candidate;
      out.come_from_[dst] = src;
      out.come_through_arc_[dst] = arc;
    }
  };

  for (unsigned src : normal_states_ids_) {
    for (const pair<unsigned, unsigned> &arc_pair : adjacency[src]) {
      relax(src, arc_pair.first, arc_pair.second,
            state_probabilities[src] * arc_p[arc_pair.second]);
    }
  }

  const vector<unsigned> &epsilon_order = forward ? epsilon_states_ids_ : epsilon_states_reverse_order_;
  for (unsigned eps : epsilon_order) {
    for (const pair<unsigned, unsigned> &arc_pair : adjacency[eps]) {
      relax(eps, arc_pair.first, arc_pair.second,
            out.new_state_probabilities_[eps] * arc_p[arc_pair.second]);
    }
    out.new_state_probabilities_[eps] = Probability::FromP(0);
  }
  return out;
}

// Number of states added so far (every kind, including epsilon states).
unsigned Hmm::GetStatesCount() const {
  return static_cast<unsigned>(states_.size());
}

// Number of arcs added so far; each AddArc appends one static probability.
unsigned Hmm::GetArcsCount() const {
  return static_cast<unsigned>(static_transition_probabilities_.size());
}
