#ifndef HMM_H
#define HMM_H
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

#include "probability.h"
#include "ctc_predictions.h"



/// Hidden Markov model whose emission and transition probabilities are
/// evaluated against a CtcPredictions object, one position at a time.
///
/// NOTE(review): the intended usage order below is inferred from the API;
/// confirm it in hmm.cpp. Add states with the Add*State() methods, connect
/// them with the Add*Arcs() methods, call Finalize(), and then run the
/// per-position Emission(), TransitionSum() and TransitionMax() steps.
///
/// Ownership: states_ and transition_computers_ store raw pointers to
/// polymorphic objects (State and TransitionComputer have virtual
/// destructors), and Hmm declares its own destructor. That destructor is
/// assumed to delete those objects. A memberwise copy would then make two
/// Hmm instances delete the same pointers, so copying is disabled. Moving
/// is allowed, because a moved-from std::vector is left empty.
class Hmm {
public:
  /// Direction of a propagation step in TransitionSum() / TransitionMax().
  enum Direction {FORWARD = 0, BACKWARD = 1};
private:
  // Sentinel id equal to the largest representable unsigned value.
  // NOTE(review): presumably marks "no predecessor / no arc" entries, for
  // example in TransitionMaxResult. Confirm in hmm.cpp.
  const unsigned kInvalidId = std::numeric_limits<unsigned>::max();

  // Abstract base of every HMM state. It stores the state's starting and
  // ending probabilities. Subclasses supply the emission probability of
  // the CTC predictions at a given position.
  class State {
    Probability starting_p_;
    Probability ending_p_;
  public:
    State(Probability starting_p,
          Probability ending_p);
    virtual Probability EmissionProbability(const CtcPredictions &ctc_predictions, int position) = 0;
    virtual ~State();
    Probability GetStartingP();
    Probability GetEndingP();
  };

  // State whose emission is driven by a fixed distribution.
  // NOTE(review): the distribution is presumably indexed by CTC label.
  class EmissionState : public State {
    std::vector<Probability> emission_distribution_;
  public:
    EmissionState(Probability starting_p,
                  Probability ending_p,
                  std::vector<Probability> distribution);
    Probability EmissionProbability(const CtcPredictions &ctc_predictions, int position) override;
  };

  // State associated with a single label, emitted_label_.
  class SimpleEmissionState : public State {
    unsigned emitted_label_;
  public:
    SimpleEmissionState(Probability starting_p,
                        Probability ending_p,
                        unsigned emitted_label);
    Probability EmissionProbability(const CtcPredictions &ctc_predictions, int position) override;
  };

  // Epsilon state. Hmm tracks these separately in epsilon_states_ids_.
  // NOTE(review): presumably these are non-emitting (silent) states.
  class EpsilonState : public State {
  public:
    EpsilonState(Probability starting_p,
                 Probability ending_p);
    Probability EmissionProbability(const CtcPredictions &ctc_predictions, int position) override;
  };

  // State parameterized by an emission matrix and a period.
  // NOTE(review): presumably the emission at `position` depends on the
  // prediction `period` positions away, looked up in emission_matrix_.
  class MatchingNucleotideState : public State {
    std::vector<std::vector<Probability>> emission_matrix_;
    int period_;
  public:
    MatchingNucleotideState(Probability starting_p,
                            Probability ending_p,
                            std::vector<std::vector<Probability>> emission_matrix,
                            int period);
    Probability EmissionProbability(const CtcPredictions &ctc_predictions, int position) override;
  };

  // Abstract computer of position-dependent transition probabilities.
  // It returns (id, probability) pairs.
  // NOTE(review): the id is presumably an arc id.
  class TransitionComputer {
  public:
    virtual std::vector<std::pair<unsigned, Probability>> Compute(const CtcPredictions &ctc_predictions,
                                                                  int position) = 0;
    virtual ~TransitionComputer();
  };

  // Transitions for the arcs in arc_ids_. Their probabilities are built
  // from per-label contributions, label masking weights and a period.
  class LeftDependentTransitions : public TransitionComputer {
    std::vector<unsigned> arc_ids_;
    std::vector<std::vector<std::pair<unsigned, Probability>>> label_contributions_;
    std::vector<Probability> label_masking_weights_;
    int period_;
  public:
    LeftDependentTransitions(std::vector<unsigned> arc_ids,
                             std::vector<std::vector<std::pair<unsigned, Probability>>> label_contributions,
                             std::vector<Probability> label_masking_weights,
                             int period);
    std::vector<std::pair<unsigned, Probability>> Compute(const CtcPredictions &ctc_predictions,
                                                          int position) override;
  };

  // Transitions for the arcs in arc_ids_, each with a default probability
  // and a period.
  class RepeatEnteringTransitions : public TransitionComputer {
    std::vector<unsigned> arc_ids_;
    std::vector<Probability> default_probabilities_;
    std::vector<int> periods_;
  public:
    RepeatEnteringTransitions(std::vector<unsigned> arc_ids,
                              std::vector<Probability> default_probabilities,
                              std::vector<int> periods);
    std::vector<std::pair<unsigned, Probability>> Compute(const CtcPredictions &ctc_predictions,
                                                          int position) override;
  };

  // Polymorphic states and transition computers, held by raw pointer
  // (see the ownership note on the class).
  std::vector<State*> states_;
  std::vector<TransitionComputer*> transition_computers_;
  // NOTE(review): presumably the fixed per-arc probabilities set by
  // AddUnconditionalArcs().
  std::vector<Probability> static_transition_probabilities_;

  // Per-state adjacency lists of (unsigned, unsigned) pairs.
  // NOTE(review): presumably (neighbour state id, arc id). Confirm.
  std::vector<std::vector<std::pair<unsigned, unsigned>>> arcs_from_, arcs_to_;
  // Bookkeeping filled by PrecomputeEpsilons(). It splits state ids into
  // epsilon and normal states.
  bool epsilons_precomputed_;
  std::vector<unsigned> epsilon_states_ids_;
  std::vector<unsigned> epsilon_states_reverse_order_;
  std::vector<unsigned> normal_states_ids_;

  void PrecomputeEpsilons();
  unsigned AddState(State *state);
  unsigned AddArc(unsigned from, unsigned to);
  std::vector<unsigned> AddArcs(unsigned from, const std::vector<unsigned> &to);

public:
  Hmm();
  ~Hmm();

  // Non-copyable: a memberwise copy would duplicate the raw State* and
  // TransitionComputer* pointers, leading to double deletion.
  Hmm(const Hmm &) = delete;
  Hmm &operator=(const Hmm &) = delete;
  // Movable by construction, so Hmm can still be returned by value. The
  // moved-from vectors are left empty, so the source's destructor has
  // nothing left to release.
  Hmm(Hmm &&) = default;

  // State factories. Each returns an unsigned id for the new state.
  // Emitting states default to starting_p = 0 and ending_p = 1. The
  // epsilon state defaults both probabilities to 0.
  unsigned AddEmissionState(std::vector<Probability> emission_distribution,
                            Probability starting_p = Probability::FromP(0),
                            Probability ending_p = Probability::FromP(1));
  unsigned AddSimpleEmissionState(unsigned emitted_label,
                                  Probability starting_p = Probability::FromP(0),
                                  Probability ending_p = Probability::FromP(1));
  unsigned AddEpsilonState(Probability starting_p = Probability::FromP(0),
                           Probability ending_p = Probability::FromP(0));
  unsigned AddMatchState(std::vector<std::vector<Probability>> emission_matrix,
                         int period,
                         Probability starting_p = Probability::FromP(0),
                         Probability ending_p = Probability::FromP(1));

  void Finalize();

  unsigned GetStatesCount() const;
  unsigned GetArcsCount() const;

  // Arc factories. Each connects from_state to every entry of to_states
  // and returns the ids of the new arcs.
  // NOTE(review): the ids are presumably returned in to_states order.
  std::vector<unsigned> AddUnconditionalArcs(unsigned from_state,
                                             std::vector<unsigned> to_states,
                                             std::vector<Probability> p_transitions);
  std::vector<unsigned> AddLeftDependentArcs(unsigned from_state,
                                             std::vector<unsigned> to_states,
                                             std::vector<std::vector<std::pair<unsigned, Probability>>> label_contributions,
                                             std::vector<Probability> label_masking_weights,
                                             int period);
  std::vector<unsigned> AddRepeatEnteringArcs(unsigned from_state,
                                              std::vector<unsigned> to_states,
                                              std::vector<Probability> default_probabilities,
                                              std::vector<int> periods);

  // Per-state starting and ending probabilities.
  std::vector<Probability> GetStartingProbabilities() const;
  std::vector<Probability> GetEndingProbabilities() const;

  // Transition probabilities at `position` of ctc_predictions.
  // NOTE(review): presumably one entry per arc.
  std::vector<Probability> TransitionProbabilities(const CtcPredictions &ctc_predictions,
                                                   int position) const;

  // Applies the states' emission probabilities at `position` to
  // state_probabilities.
  std::vector<Probability> Emission(const std::vector<Probability> &state_probabilities,
                                    const CtcPredictions &ctc_predictions,
                                    int position) const;

  // Result of a summing transition step: new per-state probabilities plus
  // per-arc probabilities.
  struct TransitionSumResult {
    std::vector<Probability> new_state_probabilities_;
    std::vector<Probability> arc_probabilities_;
  };

  // Propagates state_probabilities across the arcs in `direction`.
  // NOTE(review): presumably sums over incoming arcs (forward/backward
  // algorithm style).
  TransitionSumResult TransitionSum(const std::vector<Probability> &state_probabilities,
                                    const CtcPredictions &ctc_predictions,
                                    int position,
                                    Direction direction) const;

  // Result of a maximizing transition step: new per-state probabilities
  // plus, for each state, the predecessor state and the arc it came through.
  struct TransitionMaxResult {
    std::vector<Probability> new_state_probabilities_;
    std::vector<unsigned> come_from_;
    std::vector<unsigned> come_through_arc_;
  };

  // Like TransitionSum(), but presumably keeps the maximum over incoming
  // arcs (Viterbi style) and records back-pointers.
  TransitionMaxResult TransitionMax(const std::vector<Probability> &state_probabilities,
                                    const CtcPredictions &ctc_predictions,
                                    int position,
                                    Direction direction) const;
};

#endif
