#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "model.h"
#include "simple_model.h"
#include "full_model.h"
#include "decoder.h"
#include "fb_decoder.h"
#include "viterbi_decoder.h"

namespace py = pybind11;

PYBIND11_MODULE(hornet_hmm, m) {
  // Shorthands for the container types taken by the model constructors;
  // pybind11/stl.h handles conversion between Python sequences and these.
  using DoubleVector = std::vector<double>;
  using DoubleMatrix = std::vector<DoubleVector>;

  m.doc() = "A C++ implementation of HMM";

  // --- Models ---------------------------------------------------------------
  // The base type is registered first because pybind11 requires a parent
  // class to be known before any subclass binding names it. No py::init is
  // bound here, so Python code cannot construct a bare Model.
  py::class_<Model> model_cls(m, "Model");

  // Constructor arguments are exposed positionally only (no py::arg names).
  py::class_<SimpleModel, Model> simple_model_cls(m, "SimpleModel");
  simple_model_cls.def(py::init<int, int, double, DoubleVector,
                                double, double, double, DoubleMatrix>());

  py::class_<FullModel, Model> full_model_cls(m, "FullModel");
  full_model_cls.def(py::init<int, int,
                              double, double, double,
                              double, double, double,
                              DoubleVector, DoubleMatrix>());

  // --- Decoders -------------------------------------------------------------
  // Decode is bound once on the base class with keyword-capable argument
  // names; the Python subclasses below pick it up through inheritance.
  // NOTE(review): its docstring is currently the empty string.
  py::class_<Decoder> decoder_cls(m, "Decoder");
  decoder_cls.def("Decode", &Decoder::Decode, "",
                  py::arg("ctc_predictions"),
                  py::arg("label_priors"),
                  py::arg("model"));

  // Concrete decoders expose only a default constructor.
  py::class_<ForwardBackwardDecoder, Decoder> fb_decoder_cls(m, "ForwardBackwardDecoder");
  fb_decoder_cls.def(py::init<>());

  py::class_<ViterbiDecoder, Decoder> viterbi_decoder_cls(m, "ViterbiDecoder");
  viterbi_decoder_cls.def(py::init<>());

  // --- Metadata -------------------------------------------------------------
  // VERSION_INFO is presumably injected by the build system as a quoted
  // string literal; without it the module reports a development version.
#ifdef VERSION_INFO
  m.attr("__version__") = VERSION_INFO;
#else
  m.attr("__version__") = "dev";
#endif
}
