package org.sunflow.raytracer;

import org.sunflow.image.Color;
import org.sunflow.math.Halton;
import org.sunflow.math.OrthoNormalBasis;
import org.sunflow.math.Point3;
import org.sunflow.math.QMCSequence;
import org.sunflow.math.Vector3;
import org.sunflow.raytracer.accel.UniformGrid;
import org.sunflow.raytracer.photonmap.CausticPhotonMap;
import org.sunflow.raytracer.photonmap.GlobalPhotonMap;
import org.sunflow.system.ProgressDisplay;
import java.util.ArrayList;

/**
 * Central lighting service for the ray tracer.
 * <p>
 * Responsibilities:
 * <ul>
 * <li>Owns the scene's intersection accelerator (a {@link UniformGrid}) and the list of registered light sources.</li>
 * <li>Optionally traces photons into global and caustic photon maps and sets up an irradiance cache.</li>
 * <li>Answers radiance, irradiance, light-sampling and shadow queries issued by surface shaders.</li>
 * </ul>
 */
class LightServer {
    // lights registered before build(); converted into the lights array by build()
    private ArrayList lightList;
    private IntersectionAccelerator intAccel;
    // snapshot of lightList, filled in by build()
    private LightSource[] lights;
    // render settings, assigned by build()
    private RenderOptions options;

    // indirect illumination
    private GlobalPhotonMap globalPhotonMap;
    private CausticPhotonMap causticPhotonMap;
    private IrradianceCache irradianceCache;
    // final-gather stratification: irrN steps in phi times irrM steps in theta
    private int irrM;
    private int irrN;

    // direct illumination classification
    // NOTE(review): only SHADOWED and PENUMBRA are ever assigned in this class;
    // FULL_ILLUM is checked but never produced here.
    private static final int STATUS_SHADOWED = 0;
    private static final int STATUS_PENUMBRA = 1;
    private static final int STATUS_FULL_ILLUM = 2;

    LightServer() {
        lightList = new ArrayList();
        lights = null;
        intAccel = new UniformGrid();
        globalPhotonMap = null;
        causticPhotonMap = null;
        options = null;
    }

    /**
     * Adds an object to the intersection accelerator and connects its surface shader to this light server.
     *
     * @param object scene object to register
     */
    void registerObject(Intersectable object) {
        intAccel.add(object);
        object.getSurfaceShader().setLightServer(this);
    }

    /**
     * Queues a light source. It becomes active at the next call to {@link #build}.
     *
     * @param light light source to register
     */
    void registerLight(LightSource light) {
        lightList.add(light);
    }

    /**
     * Prepares the light server for rendering.
     * <p>
     * Steps performed:
     * <ul>
     * <li>Builds the intersection accelerator.</li>
     * <li>Snapshots the registered lights.</li>
     * <li>When global illumination or caustics are enabled, traces photons into the photon maps and sets up
     * irradiance sampling.</li>
     * </ul>
     * Statistics are printed at the end. The method returns early, without printing statistics, whenever
     * {@code output} reports cancellation.
     *
     * @param options render settings (GI, caustics, sampling, depth)
     * @param output  sink for log lines, task progress and cancellation
     */
    void build(RenderOptions options, ProgressDisplay output) {
        this.options = options;
        intAccel.build(output);
        if (output.isCanceled())
            return;
        long startTime = System.currentTimeMillis();
        output.println("[LSV] Light Server Init");
        lights = (LightSource[]) lightList.toArray(new LightSource[lightList.size()]);
        int numEmitted = 0;
        if (options.computeGI() || options.computeCaustics()) {
            globalPhotonMap = options.computeGI() ? new GlobalPhotonMap(options.getNumPhotons(), options.getNumGather(), 1e10) : null;
            causticPhotonMap = options.computeCaustics() ? new CausticPhotonMap(options.getNumPhotons(), options.getNumGather(), 1e10, 1.1) : null;
            if (lights.length > 0) {
                double[] lightPowers = buildLightPowerCdf();
                // Written as !(x > 0) so that a NaN total is rejected as well.
                if (!(lightPowers[lights.length - 1] > 0)) {
                    // No light could ever be selected, so the photon maps would never fill
                    // and the emission loop would spin forever.
                    output.println("[LSV] Lights emit no power, skipping photon tracing");
                } else {
                    numEmitted = emitPhotons(lightPowers, output);
                    if (numEmitted < 0)
                        return; // cancelled while tracing photons
                    if (!balancePhotonMaps(numEmitted, output))
                        return; // cancelled while balancing
                }
            }
            irradianceCache = options.irradianceCaching() ? new IrradianceCache(options.getIrradianceCacheTolerance(), options.getIrradianceCacheSpacing(), intAccel.getBounds()) : null;
            irrM = (int) Math.max(1, Math.sqrt(options.getIrradianceSamples() / Math.PI));
            irrN = (int) Math.max(1, (irrM * Math.PI));
        }
        long endTime = System.currentTimeMillis();
        output.println("[LSV] Light Server Statistics:");
        output.println("[LSV]   * Light sources found: " + lights.length);
        output.println("[LSV]   * Shadows:             " + (options.traceShadows() ? "on" : "off"));
        output.println("[LSV]   * Light samples:       " + options.getNumLightSamples());
        output.println("[LSV]   * Max raytrace depth:  " + options.getMaxDepth());
        output.println("[LSV]   * Emitted photons:     " + numEmitted);
        output.println("[LSV]   * Global photons:      " + (options.computeGI() ? globalPhotonMap.size() : 0));
        output.println("[LSV]   * Caustic photons:     " + (options.computeCaustics() ? causticPhotonMap.size() : 0));
        output.println("[LSV]   * Irr. cache sampling: " + (options.computeGI() ? ("" + irrM + "x" + irrN) : "0x0"));
        output.println("[LSV]   * Build time:          " + ((endTime - startTime) / 1000.0) + " secs.");
        output.println("[LSV] Done.");
    }

    /**
     * Computes the running sum of the lights' average powers.
     * Requires at least one light.
     *
     * @return array whose entry i is the summed average power of lights 0..i
     */
    private double[] buildLightPowerCdf() {
        double[] cdf = new double[lights.length];
        double runningSum = 0;
        for (int i = 0; i < lights.length; i++) {
            runningSum += lights[i].getAveragePower();
            cdf[i] = runningSum;
        }
        return cdf;
    }

    /**
     * Emits photons until the controlling photon map is full.
     * <p>
     * Each photon comes from a light chosen with probability proportional to its average power. The global
     * map controls termination when GI is on; otherwise the caustic map does.
     *
     * @param lightPowers cumulative light powers from {@link #buildLightPowerCdf}; the last entry must be &gt; 0
     * @param output      progress and cancellation sink
     * @return number of photons emitted, or -1 if cancelled
     */
    private int emitPhotons(double[] lightPowers, ProgressDisplay output) {
        boolean useGlobal = options.computeGI();
        double totalPower = lightPowers[lights.length - 1];
        output.println("[LSV] Tracing photons ...");
        output.setTask("Photon Tracing", 0, useGlobal ? globalPhotonMap.maxSize() : causticPhotonMap.maxSize());
        QMCSequence photonSampler = new Halton(2, 4);
        int numEmitted = 0;
        // NOTE(review): termination relies on the photon map filling up. If traced photons
        // are never stored (e.g. they miss all geometry), this loop only ends on cancel.
        while (useGlobal ? !globalPhotonMap.isFull() : !causticPhotonMap.isFull()) {
            output.update(useGlobal ? globalPhotonMap.size() : causticPhotonMap.size());
            double[] rnd = photonSampler.getNext();
            double rand1x = rnd[0];
            double rand = rand1x * totalPower;
            int light = -1;
            for (int i = 0; i < lights.length; i++) {
                if (rand < lightPowers[i]) {
                    light = i;
                    // rescale the sample into [0,1) within the chosen light's slice so it can be reused
                    if (i == 0)
                        rand1x = rand / lightPowers[0];
                    else
                        rand1x = (rand - lightPowers[i - 1]) / (lightPowers[i] - lightPowers[i - 1]);
                    break;
                }
            }
            if (light >= 0) {
                Point3 pt = new Point3();
                Vector3 dir = new Vector3();
                Color power = new Color();
                lights[light].getPhoton(rand1x, rnd[1], rnd[2], rnd[3], pt, dir, power);
                tracePhoton(RenderState.createPhotonState(new Ray(pt, dir)), power);
                numEmitted++;
            }
            if (output.isCanceled())
                return -1;
        }
        output.update(useGlobal ? globalPhotonMap.size() : causticPhotonMap.size());
        return numEmitted;
    }

    /**
     * Balances the enabled photon maps, scaling photon power by the number of emitted photons.
     * When GI is on, also precomputes irradiance on the global map.
     *
     * @param numEmitted number of photons emitted by {@link #emitPhotons}
     * @param output     progress and cancellation sink
     * @return {@code false} if cancelled, {@code true} otherwise
     */
    private boolean balancePhotonMaps(int numEmitted, ProgressDisplay output) {
        if (options.computeGI()) {
            output.setTask("Balancing global photon map", 0, 1);
            output.println("[LSV] Balancing global photon map ...");
            // storePhoton() discards a fraction getPhotonReductionRatio() of global photons
            // when caustics are on, so the survivors are scaled up to compensate.
            if (options.computeCaustics())
                globalPhotonMap.initialize(1.0 / numEmitted / (1.0 - options.getPhotonReductionRatio()));
            else
                globalPhotonMap.initialize(1.0 / numEmitted);
            if (output.isCanceled())
                return false;
            output.setTask("Precomputing irradiance", 0, 1);
            output.println("[LSV] Precomputing irradiance ...");
            globalPhotonMap.precomputeIrradiance(true, true);
            if (output.isCanceled())
                return false;
        }
        if (options.computeCaustics()) {
            output.setTask("Balancing caustic photon map", 0, 1);
            output.println("[LSV] Balancing caustic photon map ...");
            causticPhotonMap.initialize(1.0 / numEmitted);
            if (output.isCanceled())
                return false;
        }
        return true;
    }

    /**
     * Renders a visualization of each enabled photon map through {@code cam}.
     * Writes "gphotons.hdr" for the global map and "cphotons.hdr" for the caustic map.
     *
     * @param cam camera used to project the photons
     */
    void display(Camera cam) {
        if (options.computeGI())
            globalPhotonMap.display(cam, "gphotons.hdr");
        if (options.computeCaustics())
            causticPhotonMap.display(cam, "cphotons.hdr");
    }

    /**
     * Records a photon hit in the enabled photon maps.
     * <p>
     * A photon counts as caustic when it has taken no diffuse bounce and at least one specular bounce. When
     * caustics are also enabled, each global photon is kept only with probability
     * {@code 1 - getPhotonReductionRatio()}.
     *
     * @param state render state at the photon hit
     * @param dir   photon direction
     * @param power photon power
     */
    void storePhoton(RenderState state, Vector3 dir, Color power) {
        boolean isCaustic = (state.getDiffuseDepth() == 0) && (state.getSpecularDepth() > 0);
        if (options.computeGI() && (!options.computeCaustics() || (Math.random() >= options.getPhotonReductionRatio())))
            globalPhotonMap.storePhoton(state, dir, power, state.getDepth() == 0, isCaustic);
        if (options.computeCaustics() && isCaustic)
            causticPhotonMap.storePhoton(state, dir, power);
    }

    /**
     * Intersects the photon's ray with the scene and hands any hit to the surface shader for scattering.
     * Stops once the maximum trace depth is reached.
     */
    private void tracePhoton(RenderState state, Color power) {
        if (state.getDepth() >= options.getMaxDepth())
            return;
        intAccel.intersect(state);
        if (state.hit()) {
            state.getObject().setSurfaceLocation(state);
            state.getObject().getSurfaceShader().scatterPhoton(state, power);
        }
    }

    /**
     * Continues a photon path after a diffuse bounce. Does nothing when GI is disabled.
     */
    void traceDiffusePhoton(RenderState previous, Ray r, Color power) {
        if (!options.computeGI())
            return;
        RenderState state = RenderState.createDiffuseBounceState(previous, r);
        tracePhoton(state, power);
    }

    /** Continues a photon path after a specular bounce. */
    void traceSpecularPhoton(RenderState previous, Ray r, Color power) {
        RenderState state = RenderState.createSpecularBounceState(previous, r);
        tracePhoton(state, power);
    }

    /**
     * Traces a primary ray.
     *
     * @param r ray to trace
     * @return radiance along the ray; black on a miss
     */
    Color getRadiance(Ray r) {
        return getRadiance(RenderState.createState(r));
    }

    /**
     * Shades the first hit of the state's ray.
     *
     * @return radiance at the hit; black on a miss or at maximum depth
     */
    private Color getRadiance(RenderState state) {
        if (state.getDepth() >= options.getMaxDepth())
            return new Color(Color.BLACK);
        intAccel.intersect(state);
        if (state.hit()) {
            state.getObject().setSurfaceLocation(state);
            return state.getObject().getSurfaceShader().getRadiance(state);
        } else
            return new Color(Color.BLACK);
    }

    /** Traces a secondary ray spawned by a diffuse bounce from {@code previous}. */
    Color traceDiffuse(RenderState previous, Ray r) {
        return getRadiance(RenderState.createDiffuseBounceState(previous, r));
    }

    /** Traces a secondary ray spawned by a specular bounce from {@code previous}. */
    Color traceSpecular(RenderState previous, Ray r) {
        return getRadiance(RenderState.createSpecularBounceState(previous, r));
    }

    /**
     * Returns the indirect irradiance at the state's surface point.
     * <p>
     * Behaviour by case:
     * <ul>
     * <li>GI disabled: returns black.</li>
     * <li>No irradiance cache, or the path has already bounced diffusely: reads the global photon map
     * directly.</li>
     * <li>Otherwise: queries the irradiance cache. On a miss, it runs a final gather stratified over an
     * irrN (phi) x irrM (sin^2 theta) grid. It then inserts the result into the cache, together with
     * rotational and translational gradient estimates and the harmonic-mean hit distance.</li>
     * </ul>
     * The final-gather strata are uniform in sin^2 theta, i.e. cosine-weighted directions. The gradients
     * presumably follow Ward &amp; Heckbert's irradiance gradients; TODO confirm.
     *
     * @param state render state at the shaded point
     * @return irradiance at the point. When {@code displayIrradianceSamples()} is on, freshly computed
     *         samples instead return a bright yellow marker.
     */
    Color getIrradiance(RenderState state) {
        if (!options.computeGI())
            return new Color(Color.BLACK);
        if (irradianceCache == null || state.getDiffuseDepth() > 0)
            return globalPhotonMap.getIrradiance(state.getVertex().p, state.getVertex().n);
        Color irr = irradianceCache.getIrradiance(state.getVertex().p, state.getVertex().n);
        if (irr == null) {
            // compute new sample
            irr = new Color(Color.BLACK);
            OrthoNormalBasis onb = OrthoNormalBasis.makeFromW(state.getVertex().n);
            int hits = 0;
            // sum of 1/t over hits; converted into the harmonic-mean distance below
            double invR = 0.0;
            Vector3 w = new Vector3();

            // irradiance gradients
            Color[] rotGradient = new Color[3];
            Color[] transGradient1 = new Color[3];
            Color[] transGradient2 = new Color[3];
            for (int i = 0; i < 3; i++) {
                rotGradient[i] = new Color(Color.BLACK);
                transGradient1[i] = new Color(Color.BLACK);
                transGradient2[i] = new Color(Color.BLACK);
            }

            // irradiance gradients temp variables
            Vector3 vi = new Vector3();
            Color rotGradientTemp = new Color();
            Vector3 ui = new Vector3();
            Vector3 vim = new Vector3();
            Color transGradient1Temp = new Color();
            Color transGradient2Temp = new Color();
            Color lijm = new Color(); // L_i,j-1
            Color[] lim = new Color[irrM]; // L_i-1,j
            Color[] l0 = new Color[irrM]; // L_0,j
            for (int i = 0; i < irrM; i++) {
                lim[i] = new Color();
                l0[i] = new Color();
            }
            double rijm = 0; // R_i,j-1
            double[] rim = new double[irrM]; // R_i-1,j
            double[] r0 = new double[irrM]; // R_0, j
            for (int i = 0; i < irrN; i++) {
                // jittered azimuth inside stratum i
                double xi = (i + Math.random()) / irrN;
                double phi = 2 * Math.PI * xi;
                double cosPhi = Math.cos(phi);
                double sinPhi = Math.sin(phi);
                vi.x = -sinPhi; //Math.cos(phi + Math.PI * 0.5);
                vi.y = cosPhi; //Math.sin(phi + Math.PI * 0.5);
                vi.z = 0.0;
                onb.transform(vi);
                rotGradientTemp.set(Color.BLACK);
                ui.x = cosPhi;
                ui.y = sinPhi;
                ui.z = 0.0;
                onb.transform(ui);
                // tangent direction at the stratum's lower phi boundary
                double phim = (2.0 * Math.PI * i) / irrN;
                vim.x = Math.cos(phim + (Math.PI * 0.5));
                vim.y = Math.sin(phim + (Math.PI * 0.5));
                vim.z = 0.0;
                onb.transform(vim);
                transGradient1Temp.set(Color.BLACK);
                transGradient2Temp.set(Color.BLACK);
                for (int j = 0; j < irrM; j++) {
                    // jittered sin^2(theta) inside stratum j (cosine-weighted hemisphere)
                    double xj = (j + Math.random()) / irrM;
                    double sinTheta = Math.sqrt(xj);
                    double cosTheta = Math.sqrt(1.0 - xj);
                    w.x = cosPhi * sinTheta;
                    w.y = sinPhi * sinTheta;
                    w.z = cosTheta;
                    onb.transform(w);
                    Color lij = Color.BLACK;
                    RenderState temp = RenderState.createFinalGatherState(state, new Ray(state.getVertex().p, w));
                    intAccel.intersect(temp);
                    if (temp.hit()) {
                        invR += (1.0 / temp.getT());
                        hits++;
                        temp.getObject().setSurfaceLocation(temp);
                        lij = temp.getObject().getSurfaceShader().getRadiance(temp);
                        irr.add(lij);
                        // increment rotational gradient
                        rotGradientTemp.madd(-sinTheta / cosTheta, lij);
                    }

                    // increment translational gradient
                    double rij = temp.getT();
                    double sinThetam = Math.sqrt((double) j / irrM);
                    if (j > 0) {
                        double k = (sinThetam * (1.0 - ((double) j / irrM))) / Math.min(rij, rijm);
                        transGradient1Temp.add(Color.sub(lij, lijm).mul(k));
                    }
                    if (i > 0) {
                        double sinThetap = Math.sqrt((double) (j + 1) / irrM);
                        double k = (sinThetap - sinThetam) / Math.min(rij, rim[j]);
                        transGradient2Temp.add(Color.sub(lij, lim[j]).mul(k));
                    } else {
                        // remember the first phi column to close the loop after the last column
                        r0[j] = rij;
                        l0[j].set(lij);
                    }

                    // set previous
                    rijm = rij;
                    lijm.set(lij);
                    rim[j] = rij;
                    lim[j].set(lij);
                }

                // increment rotational gradient vector
                rotGradient[0].madd(vi.x, rotGradientTemp);
                rotGradient[1].madd(vi.y, rotGradientTemp);
                rotGradient[2].madd(vi.z, rotGradientTemp);
                // increment translational gradient vectors
                transGradient1[0].madd(ui.x, transGradient1Temp);
                transGradient1[1].madd(ui.y, transGradient1Temp);
                transGradient1[2].madd(ui.z, transGradient1Temp);
                transGradient2[0].madd(vim.x, transGradient2Temp);
                transGradient2[1].madd(vim.y, transGradient2Temp);
                transGradient2[2].madd(vim.z, transGradient2Temp);
            }

            // finish computing second part of the translational gradient
            // (wrap-around term between the last and the first phi columns, at phi = 0)
            vim.x = 0.0;
            vim.y = 1.0;
            vim.z = 0.0;
            onb.transform(vim);
            transGradient2Temp.set(Color.BLACK);
            for (int j = 0; j < irrM; j++) {
                double sinThetam = Math.sqrt((double) j / irrM);
                double sinThetap = Math.sqrt((double) (j + 1) / irrM);
                double k = (sinThetap - sinThetam) / Math.min(r0[j], rim[j]);
                transGradient2Temp.add(Color.sub(l0[j], lim[j]).mul(k));
            }
            transGradient2[0].madd(vim.x, transGradient2Temp);
            transGradient2[1].madd(vim.y, transGradient2Temp);
            transGradient2[2].madd(vim.z, transGradient2Temp);
            // scale first part of translational gradient
            double scale = (2.0 * Math.PI) / irrN;
            transGradient1[0].mul(scale);
            transGradient1[1].mul(scale);
            transGradient1[2].mul(scale);
            // sum two pieces of translational gradient
            transGradient1[0].add(transGradient2[0]);
            transGradient1[1].add(transGradient2[1]);
            transGradient1[2].add(transGradient2[2]);
            scale = Math.PI / (irrM * irrN);
            irr.mul(scale);
            rotGradient[0].mul(scale);
            rotGradient[1].mul(scale);
            rotGradient[2].mul(scale);
            // despite its name, invR now holds the harmonic-mean hit distance (0 if nothing was hit)
            invR = (hits > 0) ? (hits / invR) : 0.0;
            irradianceCache.insert(state.getVertex().p, state.getVertex().n, invR, irr, rotGradient, transGradient1);
            if (options.displayIrradianceSamples())
                return new Color(Color.YELLOW).mul(1e6);
        }
        return irr;
    }

    /**
     * Fills {@code state} with direct-lighting samples and, optionally, nearby caustic photons.
     * <p>
     * Returns immediately, without initializing samples, when {@code getIndirectDiffuse} is set and the path
     * has already bounced diffusely. Otherwise:
     * <ul>
     * <li>Each light sample picks uniformly among the lights whose {@code isVisible(state)} is true.</li>
     * <li>Each sample's radiance is scaled by {@code numVisibleLights / numLightSamples} to account for that
     * choice.</li>
     * </ul>
     *
     * @param state              render state to fill
     * @param getCaustics        also gather caustic photons (when caustics are enabled)
     * @param getIndirectDiffuse see the early-return condition above
     */
    void initLightSamples(RenderState state, boolean getCaustics, boolean getIndirectDiffuse) {
        if (getIndirectDiffuse && (state.getDiffuseDepth() > 0))
            return;
        int max = options.getNumLightSamples();
        if (options.computeCaustics() && getCaustics)
            max += options.getNumGather();
        state.initSamples(max);
        if (options.getNumLightSamples() > 0) {
            QMCSequence sampler = new Halton(2, 2);
            int numVisibleLights = 0;
            // shadowStatus is indexed by light index; visibleLights maps [0, numVisibleLights) to light indices
            int[] shadowStatus = new int[lights.length];
            int[] visibleLights = new int[lights.length];
            for (int i = 0; i < lights.length; i++) {
                shadowStatus[i] = lights[i].isVisible(state) ? STATUS_PENUMBRA : STATUS_SHADOWED;
                if (shadowStatus[i] != STATUS_SHADOWED)
                    visibleLights[numVisibleLights++] = i;
            }
            if (numVisibleLights > 0) {
                for (int i = 0; i < options.getNumLightSamples(); i++) {
                    LightSample sample = new LightSample();
                    sample.getRadiance().set(Color.BLACK);
                    // pick a light among the visible lights
                    double[] rnd = sampler.getNext();
                    double rndx = rnd[0];
                    int lidx = (int) (rndx * numVisibleLights);
                    int lightIndex = visibleLights[lidx];
                    LightSource ls = lights[lightIndex];
                    // reuse the fractional part of the selection sample for the light itself
                    rndx = (rndx * numVisibleLights) - lidx;
                    // pick sample on light source
                    // and set direction
                    ls.getSample(rndx, rnd[1], state, sample);
                    // account for probability of selecting this light amongst all visible ones
                    sample.getRadiance().mul((double) numVisibleLights / (double) options.getNumLightSamples());
                    // index by the light, not by its position in visibleLights
                    if (shadowStatus[lightIndex] == STATUS_FULL_ILLUM)
                        sample.setShadowRay(null);
                    state.addSample(sample);
                }
            }
        }
        if (options.computeCaustics() && getCaustics)
            causticPhotonMap.getNearestPhotons(state);
    }

    /**
     * Decides whether a light sample is occluded.
     * <p>
     * Invalid samples always count as shadowed. A valid sample counts as shadowed only when all of the
     * following hold:
     * <ul>
     * <li>shadow tracing is enabled;</li>
     * <li>the sample has a shadow ray;</li>
     * <li>that ray hits scene geometry.</li>
     * </ul>
     *
     * @param sample light sample to test
     * @return {@code true} if the sample contributes no light
     */
    boolean isShadowed(LightSample sample) {
        return !sample.isValid()
                || (options.traceShadows() && sample.getShadowRay() != null && intAccel.intersects(sample.getShadowRay()));
    }
}