/*
 * LeastSquaresPoly.java
 *
 * (C) 2008-2009, Alex S.
 */
 
import java.awt.*;
import java.util.*;

/**
 * Interactive applet that fits polynomials to user-supplied points with
 * ridge-regularized least squares and draws the fitted curves.
 *
 * <p>Clicking adds a point; pressing any key removes all points. On every
 * redraw, polynomials of degree 1 through 5 are fitted to the current points
 * and drawn together with their equations.
 *
 * <p>Model coordinates are obtained from screen (pixel) coordinates by
 * translating the applet centre to the origin and then applying
 * {@code pMatrix.reflecty3d()} (presumably flipping y so it points up).
 */
public class LeastSquaresPoly extends BufferedApplet
{
    /** Lowest polynomial degree fitted and drawn by {@link #render}. */
    private static final int MIN_DEGREE = 1;
    /** Highest polynomial degree fitted and drawn by {@link #render}. */
    private static final int MAX_DEGREE = 5;
    /** Default ridge penalty added to the diagonal of X^T X in {@link #solve(int)}. */
    private static final double RIDGE = 0.01;
    /** Left margin, in pixels, of the equation labels. */
    private static final int LABEL_X = 10;
    /** Radius, in pixels, of the dot drawn for each sample point. */
    private static final int DOT_RADIUS = 3;

    // applet size captured in init()
    private int width, height;

    // sample points in model coordinates; each element is the double[]
    // produced by tomodel.mult(...), with x at [0] and y at [1].
    private Vector points = new Vector();

    // matrices to move from model to screen coordinates and vice versa.
    private pMatrix toscreen = null;
    private pMatrix tomodel = null;

    /**
     * Captures the applet size and builds the screen/model transforms.
     */
    public void init(){
        super.init();
        Dimension size = getSize();
        width = size.width;
        height = size.height;
        // model = reflecty3d * translate(-centre) * screen; toscreen is the inverse.
        tomodel = pMatrix.reflecty3d().mult(pMatrix.translate3d(-width/2,-height/2,0));
        toscreen = tomodel.inv3d();
    }

    /**
     * Fits a polynomial of degree {@code n} to all points in {@code points}
     * using ridge-regularized least squares with the default penalty
     * {@value #RIDGE}.
     *
     * @param n degree of the polynomial (0 = constant, 1 = line, 2 = parabola, ...)
     * @return the {@code n + 1} coefficients, constant term first:
     *         {@code y = r[0] + r[1]*x + ... + r[n]*x^n}
     * @throws IllegalArgumentException if {@code n < 0}
     */
    public double[] solve(int n){
        return solve(n, RIDGE);
    }

    /**
     * Fits a polynomial of degree {@code n} to all points in {@code points}
     * by ridge regression: it minimises {@code ||X w - y||^2 + ridge * ||w||^2},
     * where row i of X is {@code [1, x_i, x_i^2, ..., x_i^n]}.
     *
     * <p>{@code ridge == 0} is ordinary least squares, which needs X^T X to be
     * invertible (e.g. at least n + 1 distinct x values); how {@code pMatrix.inv()}
     * reacts to a singular matrix is not visible here.
     *
     * @param n     degree of the polynomial (0 = constant, 1 = line, ...)
     * @param ridge non-negative regularisation strength
     * @return the {@code n + 1} coefficients, constant term first; all zeros
     *         when there are no points (the exact ridge solution with no data)
     * @throws IllegalArgumentException if {@code n < 0} or {@code ridge} is
     *         negative or NaN
     */
    public double[] solve(int n, double ridge){
        if (n < 0)
            throw new IllegalArgumentException("degree must be >= 0: " + n);
        if (!(ridge >= 0))
            throw new IllegalArgumentException("ridge must be >= 0: " + ridge);

        int m = points.size();
        if (m == 0)
            return new double[n + 1];

        // Design matrix X (row i = powers of x_i) and target column Y.
        pMatrix X = new pMatrix(m, n + 1);
        pMatrix Y = new pMatrix(m, 1);
        for (int i = 0; i < m; i++) {
            double[] p = (double[]) points.elementAt(i);
            for (int j = 0; j <= n; j++)
                X.set(i, j, Math.pow(p[0], j));
            Y.set(i, 0, p[1]);
        }

        // Primal ridge solution: w = (X^T X + ridge*I)^-1 (X^T y).
        //
        // The equivalent dual form is alpha = (X X^T + ridge*I)^-1 y,
        // w = X^T alpha. It only needs inner products between inputs (so it
        // admits kernel functions) but inverts an m x m matrix instead of an
        // (n+1) x (n+1) one, so the primal form is cheaper for few dimensions.
        pMatrix Xt = X.T();
        pMatrix G = Xt.mult(X);
        pMatrix w = G.add(G.I().mult(ridge)).inv().mult(Xt.mult(Y));

        double[] r = new double[n + 1];
        for (int i = 0; i < r.length; i++)
            r[i] = w.get(i, 0);
        return r;
    }

    /**
     * Redraws the whole applet when {@code damage} (inherited from
     * BufferedApplet; presumably "needs repaint") is set: background, axes,
     * sample points, fitted curves of degree 1..5 with their equations, and
     * a border.
     */
    public void render(Graphics g) {
        if (!damage)
            return;

        // clear screen
        g.setColor(Color.white);
        g.fillRect(0, 0, width, height);

        // axes through the applet centre (the model origin)
        g.setColor(Color.black);
        g.drawLine(width/2, 0, width/2, height);
        g.drawLine(0, height/2, width, height/2);

        // sample points
        g.setColor(Color.darkGray);
        for (Enumeration e = points.elements(); e.hasMoreElements(); ) {
            double[] p = toscreen.mult((double[]) e.nextElement());
            g.fillOval((int) p[0] - DOT_RADIUS, (int) p[1] - DOT_RADIUS,
                       2 * DOT_RADIUS, 2 * DOT_RADIUS);
        }

        // fitted curves, plus one equation label per degree. The labels are
        // spaced by the font's line height so they neither overlap nor get
        // clipped at the top edge.
        g.setColor(Color.blue);
        int lineHeight = g.getFontMetrics().getHeight();
        for (int d = MIN_DEGREE; d <= MAX_DEGREE; d++) {
            double[] coef = solve(d);
            drawCurve(g, coef);
            g.drawString(formatPolynomial(coef), LABEL_X, (d - MIN_DEGREE + 1) * lineHeight);
        }

        g.setColor(Color.green);
        g.drawRect(0, 0, width - 1, height - 1);
    }

    /**
     * Draws the polynomial with coefficients {@code coef} (constant term
     * first) as a polyline. It is sampled at every integer model x from one
     * column left of the applet to one column right of it, so only the
     * visible part is computed. The polyline starts at the first sample,
     * with no artificial leading segment.
     */
    private void drawCurve(Graphics g, double[] coef) {
        int half = width / 2;
        double[] prev = null;
        for (int x = -half - 1; x <= width - half; x++) {
            double[] cur = toscreen.mult(new double[]{ x, evaluate(coef, x), 0, 1 });
            if (prev != null)
                g.drawLine((int) prev[0], (int) prev[1], (int) cur[0], (int) cur[1]);
            prev = cur;
        }
    }

    /** Evaluates {@code coef[0] + coef[1]*x + ... + coef[k]*x^k} using Horner's rule. */
    private static double evaluate(double[] coef, double x) {
        double s = 0;
        for (int j = coef.length - 1; j >= 0; j--)
            s = s * x + coef[j];
        return s;
    }

    /**
     * Formats coefficients (constant term first) as an equation such as
     * {@code "y = 1.5x^2 - 0.25x + 3.0"}.
     *
     * <p>Terms appear highest power first, and each coefficient is rounded
     * to two decimals. A negative coefficient is written as {@code " - |c|"}
     * rather than {@code " + -c"}.
     */
    private static String formatPolynomial(double[] coef) {
        StringBuilder sb = new StringBuilder("y = ");
        for (int j = coef.length - 1; j >= 0; j--) {
            double c = Math.round(coef[j] * 100.0) / 100.0;
            if (j == coef.length - 1)
                sb.append(c);
            else
                sb.append(c < 0 ? " - " : " + ").append(Math.abs(c));
            if (j >= 1)
                sb.append('x');
            if (j >= 2)
                sb.append('^').append(j);
        }
        return sb.toString();
    }

    /**
     * Mouse click: adds the clicked position, converted to model
     * coordinates, as a new sample point and requests a redraw.
     *
     * @return always {@code true} (event handled)
     */
    public boolean mouseDown(Event evt, int x, int y){
        points.addElement(tomodel.mult(new double[] { x, y, 0, 1 }));
        damage = true;
        return true;
    }

    /**
     * Any key press: removes all sample points and requests a redraw.
     *
     * @return always {@code true} (event handled)
     */
    public boolean keyDown(Event evt, int key){
        points.removeAllElements();
        damage = true;
        return true;
    }
}

