/*
 * LeastSquaresNonLinear.java
 *
 * (C) 2008-2009, Alex S.
 */
 
import java.awt.*;
import java.util.*;

/**
 * Very basic least squares demo: the user clicks points, and two
 * non-linear models are fitted to them by linearising with logarithms
 * and delegating to pMatrix.fitPoly with degree 1:
 *   - exponential:  y = b * e^(a*x)   (drawn in blue)
 *   - power law:    y = b * x^a       (drawn in red)
 * Any key press clears the points.
 */
public class LeastSquaresNonLinear extends BufferedApplet
{
    private int width,height;

    // clicked points, stored as the double[] returned by tomodel.mult(...),
    // i.e. in model coordinates; index 0 is x and index 1 is y.
    private Vector points = new Vector();

    // matrices to easily move from model to screen and vice versa.
    private pMatrix toscreen = null;
    private pMatrix tomodel = null;

    /**
     * Initialize things: cache the applet size and build the
     * screen-to-model transform and its inverse.
     */
    public void init(){
        super.init();
        width = bounds().width;
        height = bounds().height;
        // presumably this puts the model origin at the foot of the vertical
        // axis drawn at x = width/4 with y pointing up -- confirm against pMatrix.
        tomodel = pMatrix.reflecty3d().mult(pMatrix.translate3d(-width/4,-height,0));
        toscreen = tomodel.inv3d();
    }

    /**
     * Copies the clicked points into an N x 2 matrix
     * (column 0 = x, column 1 = y), one row per point.
     */
    public pMatrix getPoints(){
        pMatrix xy = new pMatrix(points.size(),2);
        for(int i=0;i<points.size();i++){
            double[] p = (double[])points.elementAt(i);
            xy.set(i,0,p[0]);
            xy.set(i,1,p[1]);
        }
        return xy;
    }

    /**
     * Fits a polynomial of degree n to the clicked points.
     * Simply forwards to pMatrix.fitPoly.
     */
    public double[] fitPoly(int n){
        return pMatrix.fitPoly(getPoints(),n);
    }

    /**
     * Tells whether a point can take part in a log-linearised fit.
     * y must be strictly positive; x must be too when x is also logged.
     * Math.log gives NaN for negative input and -Infinity for zero, and
     * either value would poison the whole regression.
     */
    private static boolean hasFiniteLog(double[] p, boolean logX){
        if(!(p[1] > 0))
            return false;
        return !logX || p[0] > 0;
    }

    /**
     * Builds the matrix handed to the linear fit: column 1 is ln(y),
     * column 0 is ln(x) when logX is set and plain x otherwise.
     * Only usable points are included and the matrix is sized to exactly
     * that count, so skipped points leave no unwritten rows behind.
     */
    private pMatrix logSpacePoints(boolean logX){
        // first pass: count, so the matrix can be allocated at the right size.
        int usable = 0;
        for(int i=0;i<points.size();i++){
            if(hasFiniteLog((double[])points.elementAt(i),logX))
                usable++;
        }

        // second pass: pack the usable points into consecutive rows.
        pMatrix xy = new pMatrix(usable,2);
        int row = 0;
        for(int i=0;i<points.size();i++){
            double[] p = (double[])points.elementAt(i);
            if(!hasFiniteLog(p,logX))
                continue;
            xy.set(row,0, logX ? Math.log(p[0]) : p[0]);
            xy.set(row,1, Math.log(p[1]));
            row++;
        }
        return xy;
    }

    // fit: y = b * e^{a * x}
    // log both sides:  ln(y) = ln(b * e^{a * x}) = ln(b) + a*x
    // linear model: Y=ln(y), X=x
    /**
     * Fits y = b * e^(a*x). Points with y <= 0 are left out because
     * ln(y) is not finite for them.
     *
     * @return { b, a }: element 0 of the linear fit is ln(b) and is
     *         exponentiated here; element 1 is the rate a.
     */
    public double[] fitExp(){
        double[] r = pMatrix.fitPoly(logSpacePoints(false),1);
        r[0] = Math.exp(r[0]);
        return r;
    }

    // fit: y = b * x^a
    // log both sides: ln(y) = ln(b) + a*ln(x)
    // linear model: Y=ln(y), X=ln(x)
    /**
     * Fits y = b * x^a. Points with x <= 0 or y <= 0 are left out
     * (negative and zero values have no finite logarithm).
     *
     * @return { b, a }: element 0 of the linear fit is ln(b) and is
     *         exponentiated here; element 1 is the exponent a.
     */
    public double[] fitPow(){
        double[] r = pMatrix.fitPoly(logSpacePoints(true),1);
        r[0] = Math.exp(r[0]);
        return r;
    }

    /**
     * Rounds n to d decimal places (used only for the on-screen labels).
     */
    public double round(double n, int d){
        double d1 = Math.pow(10,d);
        return (double)Math.round( n * d1 ) / d1;
    }

    /**
     * function that gets called whenever something needs to be
     * rendered. Nothing is drawn unless the damage flag is set.
     */
    public void render(Graphics g) {
        if (damage) {
            // clear screen
            g.setColor(Color.white);
            g.fillRect(0, 0, width, height);

            // draw axis
            g.setColor(Color.black);
            g.drawLine(width/4,0,width/4,height);
            //g.drawLine(0,height/2,width,height/2);

            // draw dots.
            g.setColor(Color.darkGray);
            Enumeration e = points.elements();
            while(e.hasMoreElements()){
                double[] p = toscreen.mult((double[])e.nextElement());
                g.fillOval(((int)p[0])-3,((int)p[1])-3,6,6);
            }

            {
                // exponential fit: line[0] = b, line[1] = a.
                double[] line = fitExp();
                // render line as a chain of one-unit-wide segments.
                double[] oldp = toscreen.mult(new double[]{ -width*2, -width*2, 0, 1 });
                g.setColor(Color.blue);
                for(int x=-width;x<width*2;x++){
                    double y=line[0] * Math.exp( line[1] * x );
                    double[] newp = toscreen.mult(new double[]{ x, y, 0, 1 });
                    g.drawLine((int)oldp[0],(int)oldp[1],(int)newp[0],(int)newp[1]);
                    oldp = newp;
                }
                g.drawString("y = "+round(line[0],6)+"* e**("+round(line[1],6)+"x)",10,10);
            }


            {
                // power-law fit: line[0] = b, line[1] = a.
                double[] line = fitPow();
                // render line; only x >= 0 is drawn.
                double[] oldp = toscreen.mult(new double[]{ 0, 0, 0, 1 });
                g.setColor(Color.red);
                for(int x=0;x<width*2;x++){
                    double y= line[0] * Math.pow(x,line[1]);
                    double[] newp = toscreen.mult(new double[]{ x, y, 0, 1 });
                    g.drawLine((int)oldp[0],(int)oldp[1],(int)newp[0],(int)newp[1]);
                    oldp = newp;
                }
                g.drawString("y = "+round(line[0],6)+"* x**("+round(line[1],6)+")",10,20);
            }


            // frame around the drawing area.
            g.setColor(Color.green);
            g.drawRect(0, 0, width-1, height-1);
        }
    }

    // mouse click adds point (converted from screen to model coordinates)
    // and flags the view for redraw.
    public boolean mouseDown(Event evt, int x, int y){
        points.addElement(tomodel.mult(new double[] { x, y, 0, 1 }));
        return damage = true;
    }

    // any key resets: drops all points and flags the view for redraw.
    public boolean keyDown(Event evt, int key){
        points.removeAllElements();
        return damage = true;
    }
}

