#include <cmath>
#include <cstdio>
#include "fb_decoder.h"
using std::vector;
using std::min;

// Default constructor: the decoder holds no state that needs explicit setup.
ForwardBackwardDecoder::ForwardBackwardDecoder() = default;

// Runs a forward-backward pass of model's HMM over ctc_predictions and returns
// one Probability per prediction position.
//
// For every position p the result is is_left[p] || is_right[p], where
// (|| and |= are Probability's own operators):
//   - is_right[p] accumulates, over all states s with !model.IsBackground(s),
//     forward(p, s) * backward(p, s) / total(p). Here total(p) is the same
//     product summed over all states.
//   - is_left[p] merges, via operator|=, the normalized posterior of every arc
//     with model.IsArcAdvancingLeft(arc) evaluated at an arc position a such
//     that a - model.AdvancingArcPeriod(arc) == p. The forward and backward
//     arc values are divided by the arc's transition probability because both
//     already include it.
//   NOTE(review): |= / || look like a probabilistic OR - confirm in Probability.
//
// Memory is bounded by square-root checkpointing. A first forward sweep stores
// the forward vector only at the start of each window of big_step ~ sqrt(T)
// positions. The backward sweep then walks the windows from last to first,
// recomputing the forward vectors of one window at a time from its snapshot.
//
// Progress lines are written to stderr.
//
// @param ctc_predictions  input predictions; T = predictions_.size().
// @param model            supplies the HMM and the state/arc classification.
// @return a vector of T probabilities (empty when T == 0).
vector<Probability> ForwardBackwardDecoder::DecodeInternal(const CtcPredictions &ctc_predictions,
                                                           const Model &model) {
  const Hmm *hmm = model.GetHmm();
  const int length = static_cast<int>(ctc_predictions.predictions_.size());
  // Nothing to decode. Without this guard big_step would be 0 and the window
  // count below would divide by zero.
  if (length == 0) {
    return vector<Probability>();
  }
  const unsigned states_count = hmm->GetStatesCount();
  // Checkpoint window size; at least 1 because length >= 1.
  const int big_step = static_cast<int>(std::sqrt(static_cast<double>(length)));
  const int windows_count = (length + big_step - 1) / big_step;

  // Forward sweep. forward_snapshots[w] is the forward state vector just
  // before the emission at position w * big_step. The last window needs no
  // transition beyond it, so the loop stops one window early; every transition
  // position used here is < length.
  vector<vector<Probability>> forward_snapshots(windows_count);
  vector<Probability> state_probabilities = hmm->GetStartingProbabilities();
  forward_snapshots[0] = state_probabilities;
  for (int i = 0; i + 1 < windows_count; i++) {
    fprintf(stderr, "forward: %d / %d         \r", i*big_step, length);
    for (int position = i * big_step; position < i * big_step + big_step; position++) {
      state_probabilities = hmm->Emission(state_probabilities, ctc_predictions, position);
      auto transition = hmm->TransitionSum(state_probabilities, ctc_predictions, position+1, Hmm::FORWARD);
      state_probabilities = transition.new_state_probabilities_;
    }
    forward_snapshots[i+1] = state_probabilities;
  }

  vector<Probability> is_right(length, Probability::FromP(0));
  vector<Probability> is_left(length, Probability::FromP(0));
  // Carried across windows. After processing position p it holds the backward
  // vector including the emission at p.
  vector<Probability> backward_state_probabilities = hmm->GetEndingProbabilities();
  for (int i = windows_count-1; i >= 0; i--) {
    const int start_position = i * big_step;
    // The arc for position p is evaluated at arc position p + 1.
    const int arc_start_position = start_position + 1;
    const int end_position = min(length, start_position + big_step);
    const int arc_end_position = min(length, end_position + 1);

    fprintf(stderr, "backward: %d / %d          \r", start_position, length);

    // Recompute this window's forward quantities from its snapshot.
    //   forward_state_probabilities[k]: before the emission at start_position + k.
    //   forward_arc_probabilities[k]: forward arc values at arc position
    //                                 arc_start_position + k.
    vector<vector<Probability>> forward_state_probabilities(end_position - start_position);
    vector<vector<Probability>> forward_arc_probabilities(arc_end_position - arc_start_position);
    vector<Probability> window_state = forward_snapshots[i];
    for (int position = start_position; position < end_position; position++) {
      forward_state_probabilities[position - start_position] = window_state;
      window_state = hmm->Emission(window_state, ctc_predictions, position);
      if (position + 1 < length) {
        auto transition = hmm->TransitionSum(window_state, ctc_predictions, position+1, Hmm::FORWARD);
        window_state = transition.new_state_probabilities_;
        forward_arc_probabilities[position + 1 - arc_start_position] = transition.arc_probabilities_;
      }
    }

    for (int position = end_position-1; position >= start_position; position--) {
      const int arc_position = position+1;
      vector<Probability> backward_arc_probabilities;
      if (arc_position < length) {
        auto transition = hmm->TransitionSum(backward_state_probabilities, ctc_predictions, arc_position, Hmm::BACKWARD);
        backward_state_probabilities = transition.new_state_probabilities_;
        backward_arc_probabilities = transition.arc_probabilities_;
      }

      // The forward vector is taken before the emission and the backward one
      // after it, so each product counts the emission at this position once.
      backward_state_probabilities = hmm->Emission(backward_state_probabilities, ctc_predictions, position);

      // Normalizer: total forward * backward mass at this position.
      Probability sum = Probability::FromP(0);
      for (unsigned state_id = 0; state_id < states_count; state_id++) {
        Probability p_state = forward_state_probabilities[position-start_position][state_id] * backward_state_probabilities[state_id];
        sum += p_state;
      }

      // A zero normalizer would make every division below a division by zero.
      // Such a position contributes nothing. The backward vector has already
      // been advanced, so the recursion stays intact.
      if (!(sum > Probability::FromP(0))) {
        continue;
      }

      for (unsigned state_id = 0; state_id < states_count; state_id++) {
        Probability p_state = forward_state_probabilities[position-start_position][state_id] * backward_state_probabilities[state_id];
        if (!model.IsBackground(state_id)) {
          is_right[position] += (p_state / sum);
        }
      }

      if (arc_position < length) {
        auto transition_probabilities = hmm->TransitionProbabilities(ctc_predictions, arc_position);
        for (unsigned arc_id = 0; arc_id < transition_probabilities.size(); arc_id++) {
          // Skip zero-probability arcs: the division below would be undefined.
          if (model.IsArcAdvancingLeft(arc_id) && transition_probabilities[arc_id] > Probability::FromP(0)) {
            // Forward and backward arc values both include the transition
            // probability, so it is divided out once.
            Probability p_arc = forward_arc_probabilities[arc_position - arc_start_position][arc_id] *
                                backward_arc_probabilities[arc_id] / transition_probabilities[arc_id];
            const int left_position = arc_position - model.AdvancingArcPeriod(arc_id);
            if (left_position >= 0) {
              is_left[left_position] |= (p_arc / sum);
            }
          }
        }
      }
    }
  }

  vector<Probability> result(length);
  for (int position = 0; position < length; position++) {
    result[position] = (is_left[position] || is_right[position]);
  }

  return result;
}
