/*
 * LeastSquares.java
 *
 * (C) 2008, Alex S.
 */
 
import java.awt.*;
import java.util.*;

/**
 * Interactive line fitting.
 *
 * <p>Each mouse click adds a sample point; on every repaint the applet fits a
 * straight line {@code y = slope * x + intercept} through all points using
 * ridge-regularized least squares and draws it together with its equation.
 * Any key press clears the points.
 */
public class LeastSquares extends BufferedApplet
{
    /**
     * Default ridge (L2) penalty used by {@link #solve()}. It is added to the
     * diagonal of X^T X, which keeps the normal equations invertible even with
     * a single point or when every point shares the same x. Because the
     * intercept is fitted as the weight of a constant column of ones, the
     * penalty shrinks the intercept as well as the slope.
     */
    private static final double DEFAULT_RIDGE = 0.01;

    private int width,height;

    // clicked points, each the result of tomodel.mult({x, y, 0, 1});
    // only [0] (model x) and [1] (model y) are read when fitting.
    private final Vector<double[]> points = new Vector<double[]>();

    // matrices to move from model to screen coordinates and vice versa.
    private pMatrix toscreen = null;
    private pMatrix tomodel = null;

    /**
     * Caches the applet size and builds the model/screen transforms: model
     * space is translated so its origin sits at the centre of the applet and
     * is reflected in y via {@code pMatrix.reflecty3d()}.
     */
    public void init(){
        super.init();
        width = bounds().width;
        height = bounds().height;
        tomodel = pMatrix.reflecty3d().mult(pMatrix.translate3d(-width/2,-height/2,0));
        toscreen = tomodel.inv3d();
    }

    /**
     * Fits a line to all current points using ridge regression with the
     * default penalty {@link #DEFAULT_RIDGE}.
     *
     * @return {@code {slope, intercept}} of the fitted line
     *         {@code y = slope * x + intercept}; {@code {0, 0}} when there are no points
     */
    public double[] solve(){
        return solve(DEFAULT_RIDGE);
    }

    /**
     * Fits {@code y = w[0] * x + w[1]} to all current points by ridge regression,
     * i.e. minimizing {@code |Xw - y|^2 + ridge * |w|^2}.
     *
     * @param ridge non-negative L2 penalty. {@code 0} gives ordinary least squares,
     *              which needs at least two distinct x values; otherwise X^T X is
     *              singular and {@code pMatrix.inv()} is asked to invert it.
     * @return {@code {slope, intercept}}; {@code {0, 0}} when there are no points
     * @throws IllegalArgumentException if {@code ridge} is negative or NaN
     */
    public double[] solve(double ridge){
        if (!(ridge >= 0)) {
            throw new IllegalArgumentException("ridge must be >= 0, got " + ridge);
        }

        // one consistent copy, so the matrix sizes match the rows we fill in
        double[][] pts = snapshot();
        int n = pts.length;
        if (n == 0) {
            // the ridge solution for no data is w = 0; avoid building 0-row matrices
            return new double[]{ 0.0, 0.0 };
        }

        // design matrix: one row [x, 1] per point, so w[1] acts as the intercept
        pMatrix X = new pMatrix(n, 2);
        pMatrix Y = new pMatrix(n, 1);
        for (int i = 0; i < n; i++) {
            X.set(i, 0, pts[i][0]);
            X.set(i, 1, 1);
            Y.set(i, 0, pts[i][1]);
        }

        // Ridge regression has two equivalent closed forms:
        //   primal: w = (X^T X + aI)^-1 X^T y              -- inverts a d x d matrix (d = 2)
        //   dual:   alpha = (X X^T + aI)^-1 y, w = X^T alpha -- inverts an n x n Gram matrix
        // The primal form is cheaper here (two dimensions, any number of points).
        // The dual form only needs inner products between points, so it admits
        // kernel functions; it pays off when the dimension is high.
        pMatrix Xt = X.T();
        pMatrix G = Xt.mult(X);
        pMatrix w = G.add( G.I().smult(ridge) ).inv().mult(Xt).mult(Y);

        return new double[]{ w.get(0,0), w.get(1,0) };
    }

    /**
     * Redraws the scene when {@code damage} is set: background, axes, the
     * sample points, the fitted line with its equation, and a border.
     *
     * @param g graphics context to draw into
     */
    public void render(Graphics g) {
        if (!damage) {
            return;
        }

        // clear screen
        g.setColor(Color.white);
        g.fillRect(0, 0, width, height);

        // axes through the model origin (centre of the applet)
        g.setColor(Color.black);
        g.drawLine(width/2,0,width/2,height);
        g.drawLine(0,height/2,width,height/2);

        // sample points as small discs
        g.setColor(Color.darkGray);
        for (double[] pt : snapshot()) {
            double[] s = toscreen.mult(pt);
            g.fillOval(((int)s[0])-3,((int)s[1])-3,6,6);
        }

        // fitted line y = slope * x + intercept (model coordinates)
        double[] line = solve();
        double slope = line[0];
        double intercept = line[1];

        // toscreen inverts a reflection composed with a translation (see init),
        // so a straight model-space line stays straight on screen: transform the
        // two endpoints and draw a single segment. Model x in [-width, width]
        // covers the visible range [-width/2, width/2]; AWT clips the excess.
        double[] from = toscreen.mult(new double[]{ -width, slope * -width + intercept, 0, 1 });
        double[] to = toscreen.mult(new double[]{ width, slope * width + intercept, 0, 1 });
        g.setColor(Color.blue);
        g.drawLine((int)from[0],(int)from[1],(int)to[0],(int)to[1]);
        g.drawString("y = "+round2(slope)+
            " * x + "+round2(intercept),10,10);

        g.setColor(Color.green);
        g.drawRect(0, 0, width-1, height-1);
    }

    /** Rounds {@code v} to two decimal places for display. */
    private static double round2(double v) {
        return Math.round(v * 100.0) / 100.0;
    }

    /** Copies the current points into an array with a single (synchronized) Vector call. */
    private double[][] snapshot() {
        return points.toArray(new double[0][]);
    }

    /** Mouse click: converts the screen position to model coordinates and stores it as a point. */
    public boolean mouseDown(Event evt, int x, int y){
        points.addElement(tomodel.mult(new double[] { x, y, 0, 1 }));
        return damage = true;
    }

    /** Any key press: removes all points. */
    public boolean keyDown(Event evt, int key){
        points.removeAllElements();
        return damage = true;
    }
}

