/*
 * LeastSquaresDiscriminator.java
 *
 * (C) 2008, Alex S.
 */
 
import java.awt.*;
import java.util.*;

/**
 * Least Squares Discriminator.
 *
 * <p>An applet where the user clicks "x" and "o" sample points. It then shows the
 * linear decision boundary obtained by solving a least-squares (LS-SVM style)
 * dual KKT system over those points. The background is shaded by which side
 * of the boundary each region falls on.
 *
 * <p>A mouse click plots the class chosen with the checkboxes. A click with the
 * meta modifier (which the AWT 1.0 event model also sets for the right mouse
 * button) plots the other class. Any key clears all points.
 */
public class LeastSquaresDiscriminator extends BufferedApplet
{
    /** Side length, in pixels, of the square cells used to shade the background. */
    private static final int CELL = 4;

    /** Ridge term added to the KKT matrix to keep the linear system non-singular. */
    private static final double RIDGE = 0.001;

    private int width,height;

    // Clicked samples, stored as double[] in model coordinates.
    // Element [2] holds the label (+1 for "x", -1 for "o").
    private Vector points = new Vector();

    // matrices to easily move from model to screen and vice versa.
    private pMatrix toscreen = null;
    private pMatrix tomodel = null;

    CheckboxGroup checkboxes = null;

    /**
     * Sets up the screen/model transforms and the "x"/"o" class selector.
     */
    public void init(){
        super.init();
        Rectangle r = bounds();
        width = r.width;
        height = r.height;
        // screen -> model: translate by (-width/2, -height/2) and reflect y;
        // toscreen is the inverse mapping.
        tomodel = pMatrix.reflecty3d().mult(pMatrix.translate3d(-width/2,-height/2,0));
        toscreen = tomodel.inv3d();

        checkboxes = new CheckboxGroup();
        Panel p = new Panel();
        p.setLayout(new FlowLayout());
        p.add(new Checkbox("x", checkboxes, true));
        p.add(new Checkbox("o", checkboxes, false));
        setLayout(new BorderLayout());
        // add(Component, Object) replaces the obsolete add(String, Component).
        add(p, BorderLayout.SOUTH);
    }

    /**
     * Fits a linear classifier to the current points by solving the
     * LS-SVM-style dual system
     * <pre>
     *   [ 0  y^T ] [ b     ]   [ 0 ]
     *   [ y  H   ] [ alpha ] = [ 1 ]      with Hij = yi*yj*(xi . xj)
     * </pre>
     * with a small ridge term added, then forms the weights
     * {@code w = sum_i alpha_i * y_i * x_i}.
     *
     * @return {@code {A, B, D}} describing the boundary {@code A*x + B*y - D = 0}
     *         in model coordinates. If there are no points, it returns all zeros.
     */
    public double[] solve(){
        // Take one snapshot so the matrix sizes and the rows agree even if an
        // event handler modifies the vector meanwhile.
        Object[] samples = points.toArray();
        int n = samples.length;
        if (n == 0)
            return new double[]{ 0, 0, 0 };

        // X: one sample (x, y) per row; Y: the matching +1/-1 labels.
        pMatrix X = new pMatrix(n,2);
        pMatrix Y = new pMatrix(n,1);
        for(int i=0;i<n;i++){
            double[] p = (double[])samples[i];
            X.set(i,0,p[0]);
            X.set(i,1,p[1]);
            Y.set(i,0,p[2]);
        }

        // dual solution
        pMatrix Xt = X.T();             // transpose
        pMatrix Yt = Y.T();
        pMatrix G = X.mult(Xt);         // Gram matrix; XX^T
        pMatrix yy = Y.mult(Yt);        // yy^T
        pMatrix H = G.smult(yy);        // Hessian matrix; Hij = yi*yj*Gij

        pMatrix kkt = new pMatrix(H.rows+1,H.cols+1);   // Karush-Kuhn-Tucker
        kkt.set(0,0,0);
        kkt.set(0,1,Yt.get(0,0,1,Yt.cols));
        kkt.set(1,0,Y.get(0,0,Y.rows,1));
        kkt.set(1,1,H);

        pMatrix z1 = new pMatrix(kkt.rows,1);   // right-hand side [0,1,1,...,1]^T
        z1.set(1);
        z1.set(0,0,0);

        // Without the ridge the system is singular, e.g. when the same point is
        // clicked twice (two identical rows).
        pMatrix sol = kkt.add( kkt.I().smult(RIDGE) ).inv().mult(z1);
        double d = sol.get(0,0);                    // bias / threshold
        pMatrix alpha = sol.get(1,0,sol.rows-1,1);  // Lagrange multipliers
        // samples with larger |alpha| contribute more to the resulting line.

        // w is a 2x1 column vector: (A, B) live in rows 0 and 1 of column 0.
        pMatrix w = Xt.mult(alpha.smult(Y));

        return new double[]{ w.get(0,0), w.get(1,0), -d };   // Ax + By - D = 0
    }

    /**
     * Redraws the whole scene when the {@code damage} flag is set: the shaded
     * classification regions, axes, sample points, the boundary line, and its
     * equation.
     */
    public void render(Graphics g) {
        if (!damage)
            return;

        // clear screen
        g.setColor(Color.white);
        g.fillRect(0, 0, width, height);

        double[] line = solve();    // {A, B, D}: A*x + B*y - D = 0

        shadeBackground(g, line);

        // draw axes through the applet centre
        g.setColor(Color.black);
        g.drawLine(width/2,0,width/2,height);
        g.drawLine(0,height/2,width,height/2);

        drawPoints(g);

        g.setColor(Color.blue);
        drawBoundary(g, line);
        // A*x + B*y - D = 0  <=>  A*x + B*y = D
        g.drawString(""+line[0]+" * x + "+line[1]+" * y = "+line[2],10,10);

        g.setColor(Color.green);
        g.drawRect(0, 0, width-1, height-1);
    }

    // Fills CELL x CELL squares gray or light gray by the sign of
    // A*x + B*y - D at each square's centre.
    private void shadeBackground(Graphics g, double[] line){
        int half = CELL/2;
        for(int y=0;y<=height;y+=CELL){
            for(int x=0;x<=width;x+=CELL){
                double[] p = tomodel.mult(new double[]{ x, y, 0, 1 });
                double v = line[0] * p[0] + line[1] * p[1] - line[2];
                g.setColor(v < 0 ? Color.gray : Color.lightGray);
                g.fillRect(x-half, y-half, CELL, CELL);
            }
        }
    }

    // Draws every sample as an "x" (positive label) or "o" at its screen position.
    private void drawPoints(Graphics g){
        g.setColor(Color.black);
        Object[] samples = points.toArray();
        for(int i=0;i<samples.length;i++){
            double[] p = toscreen.mult((double[])samples[i]);
            g.drawString(p[2] > 0 ? "x" : "o",(int)p[0],(int)p[1]);
        }
    }

    // Draws A*x + B*y = D as one segment extending past the visible model area.
    // It solves for y when the line is shallow and for x when it is steep, so
    // near-vertical lines need no special epsilon.
    private void drawBoundary(Graphics g, double[] line){
        double a = line[0], b = line[1], d = line[2];
        double[] from, to;
        if (Math.abs(b) >= Math.abs(a)) {
            if (b == 0)
                return;     // a == b == 0: the equation describes no line
            // y = (D - A*x) / B
            from = new double[]{ -width, (d + a*width) / b, 0, 1 };
            to   = new double[]{  width, (d - a*width) / b, 0, 1 };
        } else {
            // x = (D - B*y) / A
            from = new double[]{ (d + b*height) / a, -height, 0, 1 };
            to   = new double[]{ (d - b*height) / a,  height, 0, 1 };
        }
        double[] p = toscreen.mult(from);
        double[] q = toscreen.mult(to);
        g.drawLine((int)p[0],(int)p[1],(int)q[0],(int)q[1]);
    }

    /**
     * Adds a sample at the clicked position. A click with the meta modifier
     * (set for the right button in the AWT 1.0 event model) plots the class
     * not currently selected in the checkboxes.
     */
    public boolean mouseDown(Event evt, int x, int y){
        boolean plotx = checkboxes.getSelectedCheckbox().getLabel().equals("x");
        if ((evt.modifiers & Event.META_MASK) != 0)
            plotx = !plotx;
        points.addElement(tomodel.mult(new double[] { x, y, plotx ? +1 : -1, 1 }));
        return damage = true;
    }

    /**
     * Any key press removes all samples.
     */
    public boolean keyDown(Event evt, int key){
        points.removeAllElements();
        return damage = true;
    }
}

