/*
 * Kmeans.java
 *
 * (C) 2008, Alex S.
 */
 
import java.awt.*;
import java.util.*;

/**
 * Interactive K-means clustering applet with K = 4.
 *
 * K-means is an EM (Expectation-Maximization) style algorithm that alternates
 * two steps:
 * <ul>
 * <li>Expectation: assign each point to its closest mean.</li>
 * <li>Maximization: move each mean to the average of the points assigned to it.</li>
 * </ul>
 * The steps repeat until no point changes its assignment.
 *
 * Clicking the mouse adds a point; pressing any key removes all points and
 * picks new random means. Points and means are kept in model coordinates,
 * whose origin is the centre of the applet (see {@code tomodel}/{@code toscreen}).
 */
public class Kmeans extends BufferedApplet
{
    private int width,height;

    // Clicked points in model coordinates: each element is the double[]
    // produced by tomodel.mult; only [0] (x) and [1] (y) are used for clustering.
    private Vector points = new Vector();
    // classes.elementAt(i) is an Integer index into means: the cluster of point i.
    private Vector classes = new Vector();

    // Per-cluster data: mean position (x, y in model coordinates), the label
    // drawn for member points, and the background colour of the cluster's region.
    private double[][] means = new double[4][2];
    private String[] meansstr = new String[]{"x","o","w","a"};
    private Color[] meanscolors = new Color[]{Color.red, Color.lightGray, Color.green, Color.blue };

    // Matrices to move between screen pixels and model coordinates.
    private pMatrix toscreen = null;
    private pMatrix tomodel = null;

    /**
     * Reads the applet size, builds the screen/model transforms and places
     * every mean at a random position.
     */
    public void init(){
        super.init();
        width = bounds().width;
        height = bounds().height;
        tomodel = pMatrix.reflecty3d().mult(pMatrix.translate3d(-width/2,-height/2,0));
        toscreen = tomodel.inv3d();

        for(int i=0;i<means.length;i++){
            randomizeMean(i);
        }
    }

    /**
     * Re-runs the clustering and redraws everything, but only when the
     * {@code damage} flag is set.
     */
    public void render(Graphics g) {
        if (!damage) {
            return;
        }

        // clear screen
        g.setColor(Color.white);
        g.fillRect(0, 0, width, height);

        cluster();

        // shade background according to the nearest mean
        shadeRegions(g);

        // draw axes through the model origin
        g.setColor(Color.black);
        g.drawLine(width/2,0,width/2,height);
        g.drawLine(0,height/2,width,height/2);

        // draw every point as its cluster's label
        g.setColor(Color.black);
        for(int i=0;i<points.size();i++){
            double[] p = toscreen.mult((double[])points.elementAt(i));
            int cls = ((Integer)classes.elementAt(i)).intValue();
            g.drawString(meansstr[cls],(int)p[0],(int)p[1]);
        }

        // border
        g.setColor(Color.black);
        g.drawRect(0, 0, width-1, height-1);
    }

    /**
     * Runs K-means on the current points until the assignments are stable.
     *
     * The first pass always recomputes the means, so a newly added point is
     * reflected even when no assignment changes. After that, the loop stops
     * as soon as an expectation step reassigns nothing. Stopping there, rather
     * than after one more maximization step (which would re-seed empty
     * clusters at random positions), guarantees that each point's label is the
     * mean that is actually closest to it under the final means. The shaded
     * regions drawn afterwards therefore agree with the point labels.
     */
    private void cluster(){
        assignPoints();
        recomputeMeans();
        while(assignPoints()){
            recomputeMeans();
        }
    }

    /**
     * Expectation step: assigns every point to its nearest mean.
     *
     * @return true if at least one point changed cluster
     */
    private boolean assignPoints(){
        boolean changed = false;
        for(int i=0;i<points.size();i++){
            double[] p = (double[])points.elementAt(i);
            int nearest = nearestMean(p[0], p[1]);
            Integer cls = (Integer)classes.elementAt(i);
            if(cls == null || cls.intValue() != nearest){
                classes.setElementAt(new Integer(nearest), i);
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Maximization step: moves every mean to the centroid of the points
     * assigned to it. A mean with no points is re-seeded at a random
     * position so that it gets another chance to capture some.
     */
    private void recomputeMeans(){
        for(int i=0;i<means.length;i++){
            double sumx = 0;
            double sumy = 0;
            int cnt = 0;
            for(int j=0;j<points.size();j++){
                if(((Integer)classes.elementAt(j)).intValue() == i){
                    double[] p = (double[])points.elementAt(j);
                    sumx += p[0];
                    sumy += p[1];
                    cnt++;
                }
            }
            if(cnt > 0){
                means[i][0] = sumx / cnt;
                means[i][1] = sumy / cnt;
            }else{
                randomizeMean(i);
            }
        }
    }

    /**
     * Returns the index of the mean closest (by squared Euclidean distance)
     * to the model point (x, y). Ties go to the lowest index.
     */
    private int nearestMean(double x, double y){
        int best = 0;
        // Start at infinity so that any real distance, however large, wins.
        double bestdist = Double.POSITIVE_INFINITY;
        for(int j=0;j<means.length;j++){
            double dx = x - means[j][0];
            double dy = y - means[j][1];
            double dist = dx*dx + dy*dy;
            if(dist < bestdist){
                best = j;
                bestdist = dist;
            }
        }
        return best;
    }

    /**
     * Paints the background in 4x4 pixel cells. Each cell is coloured by
     * the mean closest to the cell's centre.
     */
    private void shadeRegions(Graphics g){
        for(int y=0;y<=height;y+=4){
            for(int x=0;x<=width;x+=4){
                // The cell is filled at (x-2, y-2) with size 4x4, so (x, y) is
                // its centre; sampling the corner would shift region borders.
                double[] p = tomodel.mult(new double[] { x, y, 0, 1 });
                g.setColor( meanscolors[nearestMean(p[0], p[1])] );
                g.fillRect(x-2, y-2, 4, 4);
            }
        }
    }

    /**
     * Moves mean {@code i} to a uniformly random position inside the visible
     * area. Model coordinates are centred on the applet.
     */
    private void randomizeMean(int i){
        means[i][0] = Math.random()*width - width/2;
        means[i][1] = Math.random()*height - height/2;
    }

    /**
     * Adds a point at the clicked screen position. The point starts in a
     * random cluster; the next render reassigns it.
     *
     * @return always true (the event is handled and a redraw is requested)
     */
    public boolean mouseDown(Event evt, int x, int y){
        points.addElement(tomodel.mult(new double[] { x, y, 0, 1 }));
        classes.addElement(new Integer((int)(Math.random()*means.length)));
        return damage = true;
    }

    /**
     * Any key resets the applet: all points are removed and every mean is
     * placed at a new random position.
     *
     * @return always true (the event is handled and a redraw is requested)
     */
    public boolean keyDown(Event evt, int key){
        points.removeAllElements();
        classes.removeAllElements();
        for(int i=0;i<means.length;i++){
            randomizeMean(i);
        }
        return damage = true;
    }
}

