#include "viterbi_decoder.h"
#include <cstdio>
#include <limits>
#include <utility>
using std::vector;

/// Default constructor. No setup beyond default member initialization is
/// required; the entry pool and free list are filled lazily by AddEntry().
ViterbiDecoder::ViterbiDecoder() = default;

unsigned ViterbiDecoder::AddEntry(unsigned state_id) {
  PathFindingEntry entry;
  entry.ref_count_ = 1;
  entry.state_id_ = state_id;
  entry.arc_id_ = kInvalidId;
  entry.next_ = kInvalidId;
  
  unsigned id;
  if (!free_indices_.empty()) {
    id = free_indices_.back();
    path_finding_entries_[id] = entry;
    free_indices_.pop_back();
  }
  else {
    id = path_finding_entries_.size();
    path_finding_entries_.push_back(entry);
  }
  return id;
}

void ViterbiDecoder::Link(unsigned entry_from, unsigned entry_to, unsigned arc_id) {
  path_finding_entries_[entry_from].next_ = entry_to;
  path_finding_entries_[entry_from].arc_id_ = arc_id;
  path_finding_entries_[entry_to].ref_count_++;
}

/// Returns slot @p entry_id to the free list so AddEntry() can reuse it.
/// The slot's contents are not cleared; AddEntry() overwrites them on reuse.
void ViterbiDecoder::Remove(unsigned entry_id) {
  free_indices_.emplace_back(entry_id);
}

void ViterbiDecoder::DropReference(unsigned entry_id) {
  unsigned current = entry_id;
  path_finding_entries_[current].ref_count_ --;
  while (path_finding_entries_[current].ref_count_ == 0) {
    unsigned next = path_finding_entries_[current].next_;
    Remove(current);
    if (next == kInvalidId) {
      return;
    }
    current = next;
    path_finding_entries_[current].ref_count_ --;
  }
}


/// Finds the most probable HMM state path for @p ctc_predictions with a
/// backward Viterbi pass and converts it with StateSequenceToRepeats().
///
/// Positions are visited from last to first. At each position one
/// path-finding entry is allocated per HMM state. Each entry is linked, with
/// its arc, to the entry of the state it is reached from according to
/// Hmm::TransitionMax(..., Hmm::BACKWARD). The entries form reference-counted
/// chains: after the pass, the entry kept for each state at the first position
/// heads the best path starting in that state. Chains that are no longer
/// referenced are recycled through DropReference().
///
/// @param ctc_predictions Per-position predictions. The size of
///        `predictions_` is the sequence length.
/// @param model Provides the HMM and the epsilon-state test.
/// @return StateSequenceToRepeats(states, arcs, length, model).
///         - `states` is the path starting at the best-scoring first state.
///         - `arcs` holds the arcs between consecutive states, one fewer
///           element than `states`.
///         - If no path can be reconstructed, both sequences are passed on
///           empty instead of indexing out of bounds. This happens with empty
///           predictions, an HMM without states, or no state whose score
///           compares >= 0.
vector<Probability> ViterbiDecoder::DecodeInternal(const CtcPredictions &ctc_predictions, const Model &model) {
  const Hmm *hmm = model.GetHmm();
  int length = static_cast<int>(ctc_predictions.predictions_.size());
  unsigned states_count = hmm->GetStatesCount();

  // Backward pass: start from the ending probabilities and walk towards the
  // first position. `state_entries[s]` is the entry for state `s` at the most
  // recently processed position.
  vector<Probability> state_probabilities = hmm->GetEndingProbabilities();
  vector<unsigned> state_entries;
  for (int position = length - 1; position >= 0; position--) {
    if (position % 1000 == 0) {
      fprintf(stderr, "%d/%d                \r", position, length);
    }

    // Each new entry starts with one reference, owned by `new_state_entries`.
    vector<unsigned> new_state_entries(states_count, kInvalidId);
    for (unsigned state_id = 0; state_id < states_count; state_id++) {
      new_state_entries[state_id] = AddEntry(state_id);
    }
    int arc_position = position + 1;
    if (arc_position < length) {
      auto transition = hmm->TransitionMax(state_probabilities, ctc_predictions, arc_position, Hmm::BACKWARD);
      // Only come_from_/come_through_arc_ are read afterwards, so the new
      // probabilities can be moved out instead of copied.
      state_probabilities = std::move(transition.new_state_probabilities_);
      for (unsigned state_id = 0; state_id < states_count; state_id++) {
        // NOTE(review): assumes come_from_ always holds a valid state id; an
        // invalid id would index out of bounds below.
        unsigned next_id = transition.come_from_[state_id];
        unsigned arc_id = transition.come_through_arc_[state_id];
        if (model.IsEpsilon(next_id)) {
          // Epsilon successors are taken from the current position's entries
          // (presumably because they consume no prediction; confirm).
          Link(new_state_entries[state_id], new_state_entries[next_id], arc_id);
        }
        else {
          Link(new_state_entries[state_id], state_entries[next_id], arc_id);
        }
      }
      // Release the previous position's own references. Entries still linked
      // from the new position survive; the rest are recycled.
      for (unsigned entry : state_entries) {
        DropReference(entry);
      }
    }
    // new_state_entries is not used again in this iteration, so move it.
    state_entries = std::move(new_state_entries);
    state_probabilities = hmm->Emission(state_probabilities, ctc_predictions, position);
  }

  // Pick the first state maximizing (backward score * starting probability).
  // `>=` makes ties resolve to the highest state id.
  unsigned best_first_state = std::numeric_limits<unsigned>::max();
  Probability best_p = Probability::FromP(0);

  vector<Probability> starting_probabilities = hmm->GetStartingProbabilities();
  for (unsigned state_id = 0; state_id < states_count; state_id++) {
    Probability p_state = state_probabilities[state_id] * starting_probabilities[state_id];
    if (p_state >= best_p) {
      best_p = p_state;
      best_first_state = state_id;
    }
  }

  // Walk the chain from the chosen first state. The guard matters in two
  // cases: with empty predictions `state_entries` is empty, and with no
  // states (or no score >= 0) best_first_state is still the sentinel.
  // Indexing state_entries unguarded would be out of bounds in both.
  vector<unsigned> states_sequence;
  vector<unsigned> arcs_sequence;
  if (best_first_state < state_entries.size()) {
    for (unsigned entry = state_entries[best_first_state]; entry != kInvalidId; entry = path_finding_entries_[entry].next_) {
      states_sequence.push_back(path_finding_entries_[entry].state_id_);
      // The arc belongs to the step towards the successor, so the last state
      // contributes none.
      if (path_finding_entries_[entry].next_ != kInvalidId) {
        arcs_sequence.push_back(path_finding_entries_[entry].arc_id_);
      }
    }
  }

  // Release the references held by the first-position entries. This is
  // expected to free the remaining chains so the pool can be reused by the
  // next call.
  for (unsigned entry : state_entries) {
    DropReference(entry);
  }

  return StateSequenceToRepeats(states_sequence, arcs_sequence, length, model);
}
