package kohonenSOM;

import java.util.Random;

/**
 * A Kohonen self-organising map: a {@code width x height} grid of neurons, each
 * holding a weight vector of {@code num_inputs} values.
 *
 * <p>Typical use: construct, optionally configure, call {@link #init()}, then
 * repeatedly call {@link #train(double[])}. {@link #activate(double[])} selects the
 * neuron whose weights are closest (Euclidean distance) to the input;
 * {@link #train()} pulls the winner and its grid neighbours toward that input and
 * then decays the learning rate and the neighbourhood radius.
 */
public class SelfOrganisingMap
{
	protected int num_inputs;
	protected int width;
	protected int height;

	protected double[] inputs;             // most recent input vector
	protected double[][] outputs;          // per-neuron outputs, indexed [x][y]
	protected double[][][] weights = null; // per-neuron weight vectors, indexed [x][y][input]

	protected double init_learning_rate = 0.1;
	protected double learning_rate = init_learning_rate;
	protected double learning_rate_decay = 0.99;
	// Floor that train() decays the learning rate towards.
	protected double min_learning_rate = 0.1;

	// A negative initial radius means "derive it from the map width in init()".
	protected double init_update_radius = -1;
	protected double update_radius = -1;
	protected double update_radius_decay = 0.999;
	// Floor that train() decays the neighbourhood radius towards.
	protected double min_update_radius = 1.0;

	protected int winner_x = 0; // grid coordinates of the last winning neuron
	protected int winner_y = 0;

	protected boolean winner_takes_all = true;
	protected boolean use_gaussian_update = false;
	protected boolean wrap_borders = false;

	/**
	 * Creates an uninitialised map; {@link #init()} must be called before use.
	 *
	 * @param init_num_inputs length of each input (and weight) vector
	 * @param init_width      number of neurons along the x axis
	 * @param init_height     number of neurons along the y axis
	 */
	public SelfOrganisingMap(int init_num_inputs, int init_width, int init_height)
	{
		num_inputs = init_num_inputs;
		width = init_width;
		height = init_height;
	}

	/**
	 * Allocates the input, output and weight arrays, resolves the initial
	 * neighbourhood radius and randomises the weights via {@link #reset()}.
	 */
	public void init()
	{
		inputs = new double[num_inputs];
		outputs = new double[width][height];
		weights = new double[width][height][num_inputs];

		// Default radius: a third of the width (integer division), but never below 1.
		// A zero radius would stop rectangular training from updating even the winner
		// and would make the Gaussian kernel divide by zero.
		if (init_update_radius < 0.0) init_update_radius = Math.max(1.0, width / 3);
		update_radius = init_update_radius;

		reset();
	}

	/**
	 * Restores the learning rate and neighbourhood radius to their initial values and
	 * re-randomises every weight uniformly in [-0.1, 0.1).
	 */
	public void reset()
	{
		learning_rate = init_learning_rate;
		update_radius = init_update_radius;

		// NOTE(review): weights are centred on 0. If inputs are expected in [0..1],
		// centring them on 0.5 (i.e. [0.4..0.6]) may have been intended - confirm.
		Random r = new Random();
		for (int y = 0; y < height; y++)
			for (int x = 0; x < width; x++)
				for (int w = 0; w < num_inputs; w++)
					weights[x][y][w] = 0.2 * r.nextDouble() - 0.1;
	}

	/**
	 * Activates the map with {@code new_inputs} and performs one training step.
	 *
	 * @param new_inputs input vector; its length must equal the number of inputs
	 */
	public void train(double[] new_inputs)
	{
		activate(new_inputs);
		train();
	}

	/**
	 * Activates the map with {@code new_inputs} without changing any weights.
	 *
	 * @param new_inputs input vector; its length must equal the number of inputs
	 */
	public void test(double[] new_inputs)
	{
		activate(new_inputs);
	}

	/**
	 * Presents an input vector to the map. It stores a copy of the input and computes
	 * every neuron's output, which is the Euclidean distance between its weights and
	 * the input. The neuron with the smallest distance becomes the winner; the first
	 * one in y-then-x scan order wins ties. When winner-takes-all is enabled, the
	 * outputs are then replaced by 1.0 for the winner and 0.0 for every other neuron.
	 *
	 * @param new_inputs input vector; its length must equal the number of inputs
	 * @throws IllegalStateException    if {@link #init()} has not been called
	 * @throws IllegalArgumentException if {@code new_inputs} has the wrong length
	 */
	public void activate(double[] new_inputs)
	{
		if (inputs == null)
			throw new IllegalStateException("init() must be called before the map is used.");
		if (new_inputs.length != inputs.length)
			throw new IllegalArgumentException("Attempted to use incorrect number of inputs.");
		System.arraycopy(new_inputs, 0, inputs, 0, inputs.length);

		winner_x = 0;
		winner_y = 0;
		double best = Double.POSITIVE_INFINITY;

		for (int y = 0; y < height; y++)
		{
			for (int x = 0; x < width; x++)
			{
				double output = calculateOutput(weights[x][y], inputs);
				outputs[x][y] = output;
				if (output < best)
				{
					best = output;
					winner_x = x;
					winner_y = y;
				}
			}
		}

		if (winner_takes_all)
		{
			for (int y = 0; y < height; y++)
				for (int x = 0; x < width; x++)
					outputs[x][y] = 0.0;
			outputs[winner_x][winner_y] = 1.0;
		}
	}

	/**
	 * Performs one training step for the input given to the last
	 * {@link #activate(double[])}:
	 * <ol>
	 *   <li>Pulls the weights around the current winner towards that input, using a
	 *       Gaussian or rectangular neighbourhood.</li>
	 *   <li>Decays the learning rate towards {@link #getMinLearningRate()}.</li>
	 *   <li>Decays the radius towards {@link #getMinUpdateRadius()}.</li>
	 * </ol>
	 * With the default settings (initial and minimum learning rate both 0.1) the
	 * learning rate stays constant.
	 */
	public void train()
	{
		if (use_gaussian_update)
			updateGaussian();
		else
			updateRectangular();

		// Exponential decay towards a floor: value = floor + (value - floor) * decay.
		learning_rate = min_learning_rate + ((learning_rate - min_learning_rate) * learning_rate_decay);
		update_radius = min_update_radius + ((update_radius - min_update_radius) * update_radius_decay);
	}

	/**
	 * Updates every neuron with gain {@code learning_rate * exp(-d^2 / (2 * r^2))}, where
	 * {@code d} is the grid distance to the winner and {@code r} the update radius. The
	 * distance is toroidal when borders wrap. A zero radius degenerates to updating the
	 * winner alone.
	 */
	final private void updateGaussian()
	{
		int x1, y1, x2, y2;
		if (wrap_borders)
		{
			// One full period per axis, centred on the winner, so that the plain offset
			// (x - winner_x) is already the shortest toroidal offset.
			x1 = winner_x - width/2;
			y1 = winner_y - height/2;
			x2 = x1 + width-1;
			y2 = y1 + height-1;
		}
		else
		{
			x1 = 0;
			y1 = 0;
			x2 = width-1;
			y2 = height-1;
		}

		double two_sigma_sq = 2.0 * update_radius * update_radius;

		for (int y = y1; y <= y2; y++)
		{
			int dy = y - winner_y;
			int ny = (y + height) % height;

			for (int x = x1; x <= x2; x++)
			{
				int dx = x - winner_x;
				int nx = (x + width) % width;
				int dist_sq = dx*dx + dy*dy;

				double gain;
				if (two_sigma_sq > 0.0)
					gain = learning_rate * Math.exp(-dist_sq / two_sigma_sq);
				else
					gain = (dist_sq == 0) ? learning_rate : 0.0; // limit as the radius -> 0

				if (gain != 0.0)
					pullTowardsInput(weights[nx][ny], gain);
			}
		}
	}

	/**
	 * Updates, with the full learning rate, every neuron whose grid distance to the
	 * winner is strictly less than the update radius. The distance is toroidal when
	 * borders wrap. Each neuron is updated at most once per call.
	 */
	final private void updateRectangular()
	{
		// Distances are >= 0, so a non-positive (or NaN) radius selects no neuron.
		if (!(update_radius > 0.0))
			return;

		// Clamp before narrowing so a huge radius cannot overflow the int cast.
		int ur = (int) Math.min(Math.round(update_radius), Math.max(width, height));

		int x1, y1, x2, y2;
		if (wrap_borders)
		{
			// Never scan more than one full period per axis: a larger window would visit
			// (and update) the same neuron several times. Within one centred period the
			// plain offset is the shortest toroidal offset.
			if (2*ur + 1 >= width) { x1 = winner_x - width/2; x2 = x1 + width-1; }
			else { x1 = winner_x - ur; x2 = winner_x + ur; }
			if (2*ur + 1 >= height) { y1 = winner_y - height/2; y2 = y1 + height-1; }
			else { y1 = winner_y - ur; y2 = winner_y + ur; }
		}
		else
		{
			x1 = Math.max(winner_x - ur, 0);
			y1 = Math.max(winner_y - ur, 0);
			x2 = Math.min(winner_x + ur, width-1);
			y2 = Math.min(winner_y + ur, height-1);
		}

		for (int y = y1; y <= y2; y++)
		{
			int dy = y - winner_y;
			int ny = (y + height) % height;

			for (int x = x1; x <= x2; x++)
			{
				int dx = x - winner_x;
				int nx = (x + width) % width;

				if (Math.sqrt(dx*dx + dy*dy) < update_radius)
					pullTowardsInput(weights[nx][ny], learning_rate);
			}
		}
	}

	/** Moves weight vector {@code w} the fraction {@code gain} of the way towards the current input. */
	final private void pullTowardsInput(double[] w, double gain)
	{
		for (int i = 0; i < num_inputs; i++)
			w[i] += gain * (inputs[i] - w[i]);
	}

	/** A neuron's output: the Euclidean distance between its weights and the input. */
	final private double calculateOutput(double[] weights, double[] inputs)
	{ return calculateEuclideanDistance(weights, inputs); }

	/** Euclidean distance between {@code a} and {@code b}, iterating over {@code a.length} elements. */
	final private double calculateEuclideanDistance(double[] a, double[] b)
	{
		double diff;
		double sum = 0.0;
		for (int i = 0; i < a.length; i++)
		{
			diff = a[i]-b[i];
			sum += diff*diff;
		}
		return Math.sqrt(sum);
	}

	/** @return a new array {@code {winner_x, winner_y}} for the last activation */
	public int[] getWinner()
	{
		int[] winner = new int[2];
		winner[0] = winner_x;
		winner[1] = winner_y;
		return winner;
	}

	public int getNumInputs()
	{ return num_inputs; }

	public int getWidth()
	{ return width; }

	public int getHeight()
	{ return height; }

	/** @return the internal input array (not a copy); changes to it are seen by the map */
	public double[] getInputs()
	{ return inputs; }

	/** @return a fresh {@code [width][height]} copy of the outputs */
	public double[][] getOutputs()
	{ return getOutputs(new double[width][height]); }

	/**
	 * Copies the outputs into {@code dst_outputs}, allocating a new array when it is
	 * null. If the destination is smaller, only the overlapping region is copied.
	 */
	public double[][] getOutputs(double[][] dst_outputs)
	{
		if (dst_outputs == null) dst_outputs = new double[width][height];
		for (int i = 0; i < Math.min(dst_outputs.length, outputs.length); i++)
			System.arraycopy(outputs[i], 0, dst_outputs[i], 0, Math.min(dst_outputs[i].length, outputs[i].length));
		return dst_outputs;
	}

	/** @return a fresh array of the outputs flattened as index {@code y * width + x} */
	public double[] getFlattenedOutputs()
	{ return getFlattenedOutputs(new double[width * height]); }

	/**
	 * Writes the outputs into {@code dst_outputs} at index {@code y * width + x},
	 * allocating a new array when it is null.
	 */
	public double[] getFlattenedOutputs(double[] dst_outputs)
	{
		if (dst_outputs == null) dst_outputs = new double[width * height];

		int index = 0;
		for (int y = 0; y < height; y++)
			for (int x = 0; x < width; x++)
				dst_outputs[index++] = outputs[x][y];

		return dst_outputs;
	}

	public double getOutput(int x, int y)
	{ return outputs[x][y]; }

	/** @return a fresh {@code [width][height][num_inputs]} copy of all weights */
	public double[][][] getWeights()
	{ return getWeights(new double[width][height][num_inputs]); }

	/**
	 * Copies all weights into {@code dst_weights}, allocating a new array when it is
	 * null. If the destination is smaller, only the overlapping region is copied.
	 */
	public double[][][] getWeights(double[][][] dst_weights)
	{
		if (dst_weights == null) dst_weights = new double[width][height][num_inputs];
		for (int i = 0; i < Math.min(dst_weights.length, weights.length); i++)
			for (int j = 0; j < Math.min(dst_weights[i].length, weights[i].length); j++)
				System.arraycopy(weights[i][j], 0, dst_weights[i][j], 0, Math.min(dst_weights[i][j].length, weights[i][j].length));
		return dst_weights;
	}

	/** @return a fresh array of weight vectors flattened as row {@code y * width + x} */
	public double[][] getFlattenedWeights()
	{ return getFlattenedWeights(new double[width * height][num_inputs]); }

	/**
	 * Copies each neuron's weight vector into row {@code y * width + x} of
	 * {@code dst_weights}, allocating a new array when it is null.
	 */
	public double[][] getFlattenedWeights(double[][] dst_weights)
	{
		if (dst_weights == null) dst_weights = new double[width * height][num_inputs];

		int index = 0;
		for (int y = 0; y < height; y++)
			for (int x = 0; x < width; x++)
				System.arraycopy(weights[x][y], 0, dst_weights[index++], 0, num_inputs);

		return dst_weights;
	}

	/** @return a fresh copy of the weight vector of neuron (x, y) */
	public double[] getWeights(int x, int y)
	{ return getWeights(x, y, new double[num_inputs]); }

	/**
	 * Copies the weight vector of neuron (x, y) into {@code dst_weights}, allocating a
	 * new array when it is null. At most {@code dst_weights.length} values are copied.
	 */
	public double[] getWeights(int x, int y, double[] dst_weights)
	{
		if (dst_weights == null) dst_weights = new double[num_inputs];
		System.arraycopy(weights[x][y], 0, dst_weights, 0, Math.min(dst_weights.length, num_inputs));
		return dst_weights;
	}

	public void setInitialLearningRate(double new_init_learning_rate)
	{ init_learning_rate = new_init_learning_rate; }

	public double getInitialLearningRate()
	{ return init_learning_rate; }

	public void setLearningRate(double new_learning_rate)
	{ learning_rate = new_learning_rate; }

	public double getLearningRate()
	{ return learning_rate; }

	/** Sets the factor by which the learning rate's distance to its floor shrinks per training step. */
	public void setLearningRateDecay(double new_learning_rate_decay)
	{ learning_rate_decay = new_learning_rate_decay; }

	public double getLearningRateDecay()
	{ return learning_rate_decay; }

	/** Sets the floor that the learning rate decays towards (default 0.1). */
	public void setMinLearningRate(double new_min_learning_rate)
	{ min_learning_rate = new_min_learning_rate; }

	public double getMinLearningRate()
	{ return min_learning_rate; }

	/** Sets the initial radius; a negative value makes {@link #init()} derive it from the width. */
	public void setInitialUpdateRadius(double new_init_update_radius)
	{ init_update_radius = new_init_update_radius; }

	public double getInitialUpdateRadius()
	{ return init_update_radius; }

	public void setUpdateRadius(double new_update_radius)
	{ update_radius = new_update_radius; }

	public double getUpdateRadius()
	{ return update_radius; }

	/** Sets the factor by which the radius's distance to its floor shrinks per training step. */
	public void setUpdateRadiusDecay(double new_update_radius_decay)
	{ update_radius_decay = new_update_radius_decay; }

	public double getUpdateRadiusDecay()
	{ return update_radius_decay; }

	/** Sets the floor that the update radius decays towards (default 1.0). */
	public void setMinUpdateRadius(double new_min_update_radius)
	{ min_update_radius = new_min_update_radius; }

	public double getMinUpdateRadius()
	{ return min_update_radius; }

	public void setWinnerTakesAll(boolean new_winner_takes_all)
	{ winner_takes_all = new_winner_takes_all; }

	public boolean isWinnerTakesAll()
	{ return winner_takes_all; }

	public void setUseGaussianUpdate(boolean new_use_gaussian_update)
	{ use_gaussian_update = new_use_gaussian_update; }

	public boolean isUseGaussianUpdate()
	{ return use_gaussian_update; }

	public void setWrapBorders(boolean new_wrap_borders)
	{ wrap_borders = new_wrap_borders; }

	public boolean isWrapBorders()
	{ return wrap_borders; }
}