#include "matrix.h"

//---- experiment configuration ----
//EXP_NO is printed in the final summary of main(); presumably it also selects
//the data set consumed by input.h -- TODO confirm.
#define EXP_NO "sada1.h"
//NETTYPE / NETPREFIX are not referenced in the part of the file visible here.
#define NETTYPE "SRN"
#define NETPREFIX "SRN"


//train() first runs transTime*WORDSEG steps without learning (warm-up).
#define transTime 3
//Step size that scales every weight change in train().
#define alpha 0.01

//inputSets is not referenced in the part of the file visible here
//(presumably used by input.h -- TODO confirm).
#define inputSets 20
//Default value of the testSetCount argument of test().
#define testSets 20
//Number of train() passes executed by main().
#define EPOCH_COUNT 50

//Prefix of the file names under which main() saves the network
//("<NET_FILENAME><epoch>.net").
#define NET_FILENAME "x1-reset-20x"

//Do we want to reset the context (hidden state)?
//When DO_RESET is defined, RESET_STATEMENT expands to a call of resetState(n),
//so it may only be used where a variable named n is in scope. Otherwise it
//expands to an empty statement.
#define DO_RESET


#ifdef DO_RESET
#define RESET_STATEMENT resetState(n);
#else
//RESET_STATEMENT is used as the body of an if; with the empty expansion that
//becomes "if(...) ;". Warning 4390 is presumably the MSVC "empty controlled
//statement" diagnostic this would trigger -- TODO confirm.
#pragma warning (disable: 4390)
#define RESET_STATEMENT ;
#endif


#include "input.h"

//Layer sizes used throughout this file:
//  dimI - input dimension (alias of dimInput, presumably supplied by input.h)
//  dimJ - number of hidden units (reported by main() as the hidden neuron count)
//  dimK - output dimension (alias of dimOutput, presumably supplied by input.h)
#define dimI dimInput
#define dimJ 10
#define dimK dimOutput

using namespace std;


//Complete state of the recurrent network: weights, the activations of the
//last forward pass and the bookkeeping needed by train().
struct net {
	//input -> hidden weights, allocated in initNet() as dimJ x (dimI+1)
	Matrix v;
	//hidden -> hidden (recurrent) weights, dimJ x dimJ
	Matrix m;
	//hidden -> output weights, dimK x (dimJ+1)
	Matrix w;
	//hidden net input, the result of activate(): v*x2 + m*y
	Vector netHidden;
	//output net input: w*y2
	Vector netOut;
	//hidden state: Sigmoid(netHidden); zeroed by resetState()
	Vector y;
	//current input vector as returned by getSyllabe()
	Vector x;
	//x combined with vecBias
	Vector x2;
	//y combined with vecBias
	Vector y2;
	//network output: softMax(netOut)
	Vector o;
	//DerivSigm(y), computed in train()
	Vector dNetHidden;
	//not referenced in the part of the file visible here
	Vector dNetOut;
	//output error (target - output) per output unit, filled in train()
	Vector delta;

	//partV[h].val[j][i] holds the derivative of y[h] with respect to v[j][i]
	Matrix partV[dimJ];
	//partM[h].val[j][i] holds the derivative of y[h] with respect to m[j][i]
	Matrix partM[dimJ];
	//copies of partV / partM from the previous time step (made in train())
	Matrix oldPartV[dimJ];
	Matrix oldPartM[dimJ];
};	


//One-element helper vector for the bias (threshold) input. initNet() sets its
//single element to -1. It is combined with the input and hidden-state vectors
//through expressions of the form (a,vecBias) -- presumably an overloaded comma
//operator from matrix.h that concatenates the two vectors; TODO confirm.
Vector vecBias(1);

//Returns the next rand() value mapped linearly from [0, RAND_MAX] onto [-1, 1].
inline float rand2(){
	const float sample = (float)rand();
	return 2*sample/RAND_MAX - 1;
}


//Clears the recurrent context of the network: each of the dimJ hidden-unit
//activations stored in n.y is set to zero. n.y2 is left untouched.
void resetState( net& n ){
	for( int unit=0; unit<dimJ; ++unit ){
		n.y.val[unit] = 0;
	}
}


//Fills the dimJ x dimJ block of m with values drawn from rand2().
//Cells are visited row by row, so the sequence of rand2() draws is
//row 0 left-to-right, then row 1, and so on.
void init_w( Matrix& m ){
	const int side = dimJ;
	const int cells = side*side;
	for( int cell=0; cell<cells; ++cell ){
		const int row = cell/side;
		const int col = cell%side;
		m[row][col] = rand2();
	}
}

//Initialises the dimJ x (dimI+1) block of m: every cell becomes either
//-0.5 or +0.5, chosen by the lowest bit of one rand() call per cell
//(cells are visited row by row).
void init_wIn( Matrix& m ){
	const int cols = dimI+1;
	for( int row=0; row<dimJ; ++row ){
		for( int col=0; col<cols; ++col ){
			//rand()%2 is 0 or 1: 0 gives -0.5, 1 gives +0.5
			m.val[row][col] = (rand()%2) ? 0.5 : -0.5;
		}
	}
}

//Fills the dimK x (dimJ+1) block of m with values drawn from rand2(),
//one output row after another.
void init_wOut( Matrix& m ){
	const int cols = dimJ+1;
	for( int out=0; out<dimK; ++out ){
		for( int col=0; col<cols; ++col ){
			m.val[out][col] = rand2();
		}
	}
}

//Allocates and initialises every part of the network that train() and
//getActivities() rely on:
//  - the three weight matrices (filled by init_wIn / init_w / init_wOut,
//    in that order),
//  - a zeroed hidden state n.y and its bias-extended form n.y2,
//  - the output error vector n.delta,
//  - the global vecBias (single element, set to -1),
//  - the sensitivity matrices n.partV[h] / n.partM[h], all set to zero.
//n.oldPartV / n.oldPartM are not touched here; train() assigns them before
//reading them.
void initNet( net& n ){
	int unit,row,col;

	//input weights
	n.v = Matrix( dimJ, dimI+1 );
	init_wIn(n.v);
	//recurrent (hidden-to-hidden) weights
	n.m = Matrix( dimJ, dimJ );
	init_w(n.m);
	//output weights
	n.w = Matrix( dimK, dimJ+1 );
	init_wOut( n.w );

	//initial hidden state
	n.y = Vector(dimJ);
	resetState(n);

	//storage for the output error
	n.delta = Vector(dimK);

	//helper vector for the threshold (bias) input
	vecBias = Vector(1);
	vecBias.val[0] = -1;

	n.y2 = (n.y,vecBias);

	//sensitivities start at zero
	for( unit=0; unit<dimJ; ++unit ){
		Matrix& sensV = n.partV[unit];
		Matrix& sensM = n.partM[unit];
		sensV = Matrix( dimJ, dimI+1 );
		sensM = Matrix( dimJ, dimJ );
		for( row=0; row<dimJ; ++row ){
			for( col=0; col<dimI+1; ++col )
				sensV.val[row][col] = 0;
			for( col=0; col<dimJ; ++col )
				sensM.val[row][col] = 0;
		}
	}
}




//Computes the net input of the hidden layer.
//Requires the new (bias-extended) input in n.x2
//and the previous hidden state in        n.y.
//Returns a dimJ-element vector holding v*x2 + m*y; n itself is not modified.
Vector activate( net& n ){
	Vector fromInput = n.v*n.x2;
	Vector fromContext = n.m*n.y;

	Vector total(dimJ);
	for( int unit=0; unit<dimJ; ++unit ){
		total.val[unit] = fromInput.val[unit] + fromContext.val[unit];
	}
	return total;
}


#include "netReadWrite.h"




//Performs one forward step of the network.
//Requires the syllable in	inputStr.
//On return, n.y and n.y2 hold the new hidden state
//and n.o holds the network output, which is also returned.
//Uses n.y as left by the previous step (or by resetState()) as the context.
Vector getActivities( char* inputStr, net& n ){
	//encode the input and extend it by the bias element
	n.x = getSyllabe( inputStr );
	n.x2 = (n.x,vecBias);
	//activate the network: new hidden state from input and previous state
	n.netHidden = activate( n );
	n.y = Sigmoid( n.netHidden );
	n.y2 = (n.y,vecBias);
	//compute the output
	n.netOut = n.w*n.y2;
	n.o = softMax(n.netOut);
	return n.o;
}



#include "testWord.h"

//Runs one online training pass over the input sequence.
//
//Steps:
//  1. calls initInput() and clears the hidden state,
//  2. prints a row of '*' as a progress scale,
//  3. drives the network through transTime*WORDSEG steps without learning,
//  4. for every step t in [0, trainMax): optionally resets the state on a word
//     boundary (RESET_STATEMENT), runs the forward pass and immediately
//     adjusts v, m and w with step size alpha.
//
//The derivatives of each hidden activation y[h] with respect to v[j][i] and
//m[j][i] are carried from step to step in n.partV / n.partM (forward-mode,
//RTRL-style recursion); n.oldPartV / n.oldPartM hold the previous step.
//
//n - the network; its weights, state and sensitivity matrices are modified.
void train( net& n ){
	initInput();
	resetState(n);


	Vector oldY2(0);
	Vector d(0);
	int t,i,j,k,l,h;

	int trainMax = (inputLen-1)*WORDSEG;
	//one progress mark per 1/20 of the run; clamped to 1 so that runs shorter
	//than 20 steps do not divide (or take a modulo) by zero below
	int progressUnit = trainMax/20;
	if( progressUnit < 1 )
		progressUnit = 1;
	for( i=0; i<trainMax/progressUnit; i++ )
		cout<<"*";
	cout<<endl;

	//initial transient: forward passes only, no weight changes
	for( t=0; t<transTime*WORDSEG; t++ ){
		getActivities( getInputStr(t), n );
	}

	//training
	for( t=0; t<trainMax; t++ ){
		//reset the state if we are on a word boundary
		if( t % WORDSEG == 0){
			RESET_STATEMENT
		}
		//progress
		if( t%progressUnit == (progressUnit-1) )
			cout<<"*"<<flush;

		//previous hidden state (bias-extended), needed for the m sensitivities
		oldY2 = (n.y,vecBias);


		getActivities( getInputStr(t), n );
		d = getOutput(t);
		n.dNetHidden = DerivSigm( n.y );


		//keep the sensitivities of the previous step
		for(h=0; h<dimJ; h++){
			n.oldPartV[h] = n.partV[h];
			n.oldPartM[h] = n.partM[h];
		}



		//output error: target minus actual output
		for( k=0; k<dimK; k++ ){
			n.delta.val[k] = (d.val[k] - n.o.val[k]);
		}


		//change of v
		for(j=0; j<dimJ; j++)
			for(i=0; i<dimI+1; i++){

				//new derivative of y[h] with respect to v[j][i] for every h;
				//it does not depend on the output unit, so it is computed
				//once here rather than once per k
				for( h=0; h<dimJ; h++ ){
					NUM sumL = 0;
					for(l=0; l<dimJ; l++)
						sumL += n.m[h][l] * n.oldPartV[l].val[j][i];
					//direct dependence of unit j on its own weight
					if(h==j)
						sumL += n.x2.val[i];
					n.partV[h].val[j][i] = n.dNetHidden.val[h] * sumL;
				}

				//propagate the output error through w onto v[j][i]
				NUM sumK = 0;
				for( k=0; k<dimK; k++ ){
					NUM sumJ = 0;
					for( h=0; h<dimJ; h++ )
						sumJ += n.w[k][h] * n.partV[h].val[j][i];
					sumK += n.delta.val[k] * sumJ;
				}

				n.v.val[j][i] = n.v.val[j][i] + alpha * sumK;
			}


		//change of m
		//NOTE(review): n.m is modified inside this loop, so the recursion for
		//later (j,i) pairs already reads the updated recurrent weights. This
		//matches the original behaviour and is kept as is.
		for(j=0; j<dimJ; j++)
			for(i=0; i<dimJ; i++){

				//new derivative of y[h] with respect to m[j][i] for every h
				for( h=0; h<dimJ; h++ ){
					NUM sumL = 0;
					for(l=0; l<dimJ; l++)
						sumL += n.m[h][l] * n.oldPartM[l].val[j][i];
					//direct dependence: m[j][i] multiplies the previous y[i]
					if(h==j)
						sumL += oldY2.val[i];
					n.partM[h].val[j][i] = n.dNetHidden.val[h] * sumL;
				}

				//propagate the output error through w onto m[j][i]
				NUM sumK = 0;
				for( k=0; k<dimK; k++ ){
					NUM sumJ = 0;
					for( h=0; h<dimJ; h++ )
						sumJ += n.w[k][h] * n.partM[h].val[j][i];
					sumK += n.delta.val[k] * sumJ;
				}

				n.m.val[j][i] = n.m.val[j][i] + alpha * sumK;
			}

		//change of w (uses the bias-extended hidden state)
		for( k=0; k<dimK; k++ )
			for( j=0; j<dimJ+1; j++){
				n.w.val[k][j] = n.w.val[k][j] + alpha*n.delta.val[k]*n.y2.val[j];
			}


	}
	cout<<endl;
}





//Evaluates the network on testSetCount * WORDCOUNT words and prints a report.
//
//For every step the forward pass is run and the output is mapped to a word
//index via applyOneHotFilter / applyDiscriminatorNum. Steps whose index is -1
//are not counted; for the others the matching code word is compared with the
//target. Errors and counted steps are tallied per position within the word
//(t % WORDSEG) and printed as "errors/count". The summed d.dist(n.o) divided
//by the number of words is printed as the absolute error.
//
//n            - the network; its state is changed by the forward passes.
//testSetCount - number of test sets, defaults to testSets.
void test(net& n, int testSetCount = testSets){
	Vector d(0);
	int seg;

	//error bookkeeping; errCount is the overall total (accumulated, not printed)
	int errCount = 0;
	int errCountMod[WORDSEG];
	int countMod[WORDSEG];
	for( seg=0; seg<WORDSEG; seg++ )
		errCountMod[seg] = countMod[seg] = 0;
	//absolute error
	float errAbs = 0;

	int testLen = testSetCount * WORDCOUNT;
	initInput( testLen );

	const int stepCount = testLen*WORDSEG;
	for( int t=0; t<stepCount; t++ ){
		seg = t % WORDSEG;
		//reset the state if we are on a word boundary
		if( seg == 0 ){
			RESET_STATEMENT
		}

		getActivities( getInputStr(t), n );
		d = getOutput(t);

		const int wordNum = applyDiscriminatorNum( applyOneHotFilter(n.o) );
		if( wordNum != -1 ){
			countMod[seg]++;
			if( !(codeWords[wordNum]==d) ){
				errCount++;
				errCountMod[seg]++;
			}
		}

		errAbs += d.dist( n.o );
	}

	cout<<"Chyby:";
	for( seg=0; seg<WORDSEG; seg++ ){
		cout<<"  "<<errCountMod[seg]<<"/"<<countMod[seg];
	}
	cout<<endl;
	cout<<"Abs. chyba: "<<(errAbs/testLen)<<endl;
}



//Program entry point.
//
//With a command-line argument: reads a saved network from argv[1], runs the
//code textually included from test.h and exits.
//Without arguments: trains for EPOCH_COUNT epochs. After each epoch the
//network is evaluated (test(n,1), testWord on up to 1000 words, test.h) with
//the hidden state saved before and restored afterwards; the network is
//written to disk after epochs 2, 5, 10, 20 and 50 and once more when training
//ends. Finally a full test and a summary of the run are printed.
//
//The locals n, i, buf and seed keep these exact names because the code pulled
//in by #include "test.h" is compiled inside this function and may refer to
//them -- presumably; test.h is not visible here.
int main( int argc, char *argv[] ){
	unsigned int seed =(unsigned)time( NULL );
	srand(seed);

	net n;
	int i;
	char buf[WORDSIZE+1];

	initNet(n);
	initCodeWords();


	//TESTING AN ALREADY TRAINED NETWORK
	if(argc>1){
		ReadNet( n, argv[1] );

		#include "test.h"

		return 0;
	}


	//TRAINING
	int epoch;
	char net_fileName[100];
	for( epoch=0; epoch<EPOCH_COUNT; epoch++ ){
		cout<<"epocha "<<epoch<<endl;
		train(n);

		//remember the hidden state so that evaluation does not disturb it
		Vector tmp = n.y;
		test(n,1);

		for( i=0; i<WORDCOUNT && i<1000; i++){
			RESET_STATEMENT
			testWord( arrWords[i], n);
		}
		cout<<"----------------------------"<<endl;
		#include "test.h"

		n.y = tmp;

		//write the network after selected epochs
		const int epochsDone = epoch+1;
		bool saveNow;
		switch( epochsDone ){
			case 2:
			case 5:
			case 10:
			case 20:
			case 50:
				saveNow = true;
				break;
			default:
				saveNow = false;
				break;
		}
		if( saveNow ){
			sprintf( net_fileName, "%s%d.net", NET_FILENAME, epochsDone );
			cout<<"Zapisujem subor "<<net_fileName<<endl;
			WriteNet( n, net_fileName );
		}
	}
	//final snapshot, named after the number of completed epochs
	sprintf( net_fileName, "%s%d.net", NET_FILENAME, epoch );
	cout<<"Zapisujem subor "<<net_fileName<<endl;
	WriteNet( n, net_fileName );


	cout<<"trenovanie ukoncene"<<endl;
	cout<<endl;

	//TESTING
	test(n);

	//summary of the run
	cout<<endl;
	cout<<"-------------------------------"<<endl;
	cout<<"pocet epoch: "<<EPOCH_COUNT<<endl;
	cout<<"inputLen "<<inputLen<<endl;
	cout<<"EXP_NO "<<EXP_NO<<endl;
	cout<<"pocet skrytych neuronov "<<dimJ<<endl;
	cout<<"\t <random seed "<<seed<<">"<<endl;

	delete []inputSeq;

	return 0;
}

