/*
 * Hopfield.java
 *
 * (C) 2008-2009, Alex S.
 */
 
import java.awt.*;
import java.awt.event.*;
import java.util.*;

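/**
 * Interactive Hopfield-network demo applet. Most of this file is AWT GUI
 * plumbing; the interesting parts are:
 *
 *  learnmemories(): compute the network weights from the list of memories
 *                   (Hebbian rule), then re-run recall.
 *  recalc():        starting from the input pattern, iterate the network to
 *                   try to reconstruct one of the stored memories.
 */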

/**
 * AWT canvas that draws a {@link pMatrix} as a grid of square cells and
 * repaints whenever the matrix notifies its listeners.
 *
 * In black-and-white mode a cell is black when its value is positive and
 * white otherwise. In grayscale mode (used for the weight matrix) each cell is
 * shaded linearly between the matrix minimum (white) and maximum (black).
 */
class HopMemoryView extends Canvas {

    pMatrix m;              // the matrix being displayed
    int box;                // side length of one cell, in pixels
    boolean grayscale;      // true: shade by value; false: black/white by sign

    /**
     * Creates a view of {@code matrix}.
     *
     * @param matrix    matrix to display; the view registers a listener on it
     *                  and repaints when notified
     * @param boxsize   side length of each cell, in pixels
     * @param updatable if true, clicking a cell toggles it (a positive value
     *                  becomes -1, anything else becomes 1) and calls
     *                  {@code matrix.update()}
     * @param grayscale if true, render values as gray levels instead of
     *                  black/white
     */
    public HopMemoryView(pMatrix matrix,int boxsize,boolean updatable,boolean grayscale){
        super();
        m = matrix;
        // Have the matrix tell us when it changes so we can redraw.
        m.addListener(
            new pMatrix.Listener(){
                public void update(){
                    repaint();
                }
            }
        );

        box = boxsize;
        this.grayscale = grayscale;
        setSize(m.cols * box,m.rows * box);

        // If the view is editable, let the user toggle cells by clicking.
        if(updatable){
            addMouseListener(
                new MouseAdapter() {
                    @Override
                    public void mouseClicked(MouseEvent e) {
                        int x = e.getX(), y = e.getY();
                        // Ignore clicks outside the drawn grid (e.g. if a layout
                        // manager ever stretches the canvas past cols*box x rows*box).
                        if(x < 0 || y < 0)
                            return;
                        int i = y / box, j = x / box;
                        if(i >= m.rows || j >= m.cols)
                            return;
                        m.set(i,j, m.get(i,j) > 0 ? -1 : 1);
                        m.update();         // notify the matrix's listeners (views repaint)
                    }
                }
            );
        }
        setVisible(true);
        repaint();
    }

    /**
     * Returns the exact grid size, so layout managers do not depend on the
     * platform peer to echo back the size given to {@code setSize}.
     */
    @Override
    public Dimension getPreferredSize(){
        return new Dimension(m.cols * box, m.rows * box);
    }

    /** The grid cannot usefully be shown smaller than its preferred size. */
    @Override
    public Dimension getMinimumSize(){
        return getPreferredSize();
    }

    /**
     * Repaints without first clearing to the background colour (which is what
     * {@code Canvas.update} does by default). {@link #paint} fills every grid
     * cell, so the clear would only add flicker.
     */
    @Override
    public void update(Graphics g){
        paint(g);
    }

    /** Draws every matrix element as a filled square, plus a border. */
    @Override
    public void paint(Graphics g){
        // Grayscale needs the value range; black/white mode only needs the sign.
        double min = 0, max = 0;
        if(grayscale){
            double[] minmax = m.minmax();
            min = minmax[0];
            max = minmax[1];
        }
        double range = max - min;

        for(int i=0;i<m.rows;i++){
            for(int j=0;j<m.cols;j++){
                double v = m.get(i,j);
                if(grayscale){
                    // max -> 0 (black), min -> 255 (white); a constant matrix is all black.
                    int col = 0;
                    if(range > 0){
                        col = (int)( (max - v) / range * 0xFF );
                        // Clamp rather than mask: masking would wrap an out-of-range
                        // value to the opposite end of the gray scale.
                        col = Math.max(0, Math.min(0xFF, col));
                    }
                    g.setColor(new Color(col,col,col));
                }else{
                    g.setColor(v > 0 ? Color.black : Color.white);
                }
                g.fillRect(j*box,i*box,box,box);
            }
        }
        g.setColor(Color.darkGray);
        g.drawRect(0,0,m.cols*box-1,m.rows*box-1);
    }
}

/**
 * Applet demonstrating a Hopfield network used as an associative memory.
 *
 * The user edits a few small binary (+1/-1) patterns ("memories") and a noisy
 * input pattern. The network weights are learned from the memories with the
 * Hebbian rule, and the output shows the state the network settles into when
 * started from the input.
 *
 * NOTE(review): java.applet.Applet is deprecated for removal (JEP 398); a
 * standalone Frame host would be needed on modern JDKs.
 */
public class Hopfield extends java.applet.Applet {

    /** Rows of every memory / input / output pattern. */
    private static final int ROWS = 8;
    /** Columns of every memory / input / output pattern. */
    private static final int COLS = 8;
    /** Number of editable memories shown at the top. */
    private static final int NUM_MEMORIES = 3;
    /** Probability with which each input cell is flipped when seeding the input. */
    private static final double NOISE_PROBABILITY = 0.10;
    /**
     * Cap on recall iterations. Synchronous updates can oscillate between two
     * states instead of converging, so the loop must be bounded.
     */
    private static final int MAX_RECALL_LOOPS = 50;
    /** Width in bits of each row of the built-in glyph bitmaps. */
    private static final int GLYPH_WIDTH = 8;

    pList memories = null;      // stored patterns (pMatrix, ROWS x COLS)
    pMatrix inmatrix = null;    // user-editable input pattern
    pMatrix outmatrix = null;   // recalled pattern
    pMatrix hopnet = null;      // weight matrix, (ROWS*COLS) x (ROWS*COLS)

    /**
     * Builds the GUI, seeds the memories with letters, seeds the input with a
     * noisy copy of the first memory, then learns the weights and runs recall.
     */
    @Override
    public void init(){
        memories = new pList();
        setLayout(new BorderLayout());

        // Memory editors. HopMemoryView's own listener toggles the clicked cell;
        // the listener added here then retrains. NOTE(review): this relies on AWT
        // notifying listeners in registration order.
        Panel p = new Panel();
        for(int k=0;k<NUM_MEMORIES;k++){
            pMatrix m = new pMatrix(ROWS,COLS);
            memories.push(m);
            Canvas view = new HopMemoryView(m,10,true,false);
            p.add(view);
            view.addMouseListener(
                new MouseAdapter() {
                    @Override
                    public void mouseClicked(MouseEvent e) {
                        learnmemories();
                    }
                }
            );
        }
        add(BorderLayout.NORTH,packageWithLabel("Memories (you can change these by clicking)",p));

        initMemoriesToSensibleValues();

        // The input starts as a noisy copy of the first memory.
        inmatrix = ((pMatrix)memories.get(0)).dup();
        outmatrix = new pMatrix(ROWS,COLS);
        hopnet = new pMatrix(ROWS*COLS,ROWS*COLS);
        addNoise(inmatrix,NOISE_PROBABILITY);

        p = new Panel();
        Canvas inputView = new HopMemoryView(inmatrix,16,true,false);
        p.add(packageWithLabel("Input (click to change)",inputView));
        // Same ordering assumption as above: toggle first, then recall.
        inputView.addMouseListener(
            new MouseAdapter() {
                @Override
                public void mouseClicked(MouseEvent e) {
                    recalc();
                }
            }
        );
        p.add(packageWithLabel("Hopfield Net Weights",new HopMemoryView(hopnet,2,false,true)));
        add(BorderLayout.CENTER,p);

        p = new Panel();
        p.add(packageWithLabel("Output (reconstructing input from memories)",new HopMemoryView(outmatrix,16,false,false)));
        add(BorderLayout.SOUTH,p);

        learnmemories();
    }

    /** Flips the sign of each cell of {@code m} independently with probability {@code p}. */
    private static void addNoise(pMatrix m, double p){
        for(int r=0;r<m.rows;r++){
            for(int c=0;c<m.cols;c++){
                if(Math.random() < p){
                    m.set(r,c,m.get(r,c) * -1);
                }
            }
        }
    }

    /**
     * Wraps {@code c} in a panel with a centered caption above it. AWT needs the
     * extra nested panels so that the component keeps its preferred size.
     *
     * @param s caption text
     * @param c component to show under the caption
     * @return the wrapping panel
     */
    public Component packageWithLabel(String s, Component c){
        Panel p = new Panel();
        p.setLayout(new BorderLayout());
        p.add(BorderLayout.NORTH,new Label(s,Label.CENTER));
        Panel q = new Panel();
        q.setLayout(new FlowLayout(FlowLayout.CENTER));
        q.add(c);
        p.add(BorderLayout.CENTER,q);
        return p;
    }

    /**
     * Fills each memory with a letter bitmap (A, B, C, ... reusing letters
     * cyclically if there are more memories than glyphs). Lit pixels become 1,
     * all others -1. Glyphs are left-aligned; cells beyond the 8x8 glyph are -1.
     */
    public void initMemoriesToSensibleValues(){
        // One byte per row; bit 7 is the leftmost pixel.
        int[][] alphabet = {
            {0x38,0x6C,0xC6,0xC6,0xFE,0xC6,0xC6,0x00},      // A
            {0xFC,0x6E,0x66,0x7C,0x66,0x6E,0xFC,0x00},      // B
            {0x3E,0x62,0xC0,0xC0,0xC0,0x62,0x3E,0x00},      // C
            {0xF8,0x6E,0x66,0x66,0x66,0x6E,0xF8,0x00},      // D
            {0xFE,0x62,0x60,0x78,0x60,0x62,0xFE,0x00},      // E
            {0xFE,0x62,0x60,0x78,0x60,0x60,0xF0,0x00}       // F
        };
        for(int k=0;k<memories.size();k++){
            pMatrix m = (pMatrix)memories.get(k);
            int[] glyph = alphabet[k % alphabet.length];
            for(int r=0;r<m.rows;r++){
                for(int c=0;c<m.cols;c++){
                    int bit = GLYPH_WIDTH - 1 - c;      // column 0 <-> bit 7
                    boolean lit = r < glyph.length && bit >= 0 && (glyph[r] & (1<<bit)) != 0;
                    m.set(r,c,lit ? 1 : -1);
                }
            }
        }
    }

    /**
     * Runs recall. Starting from the input pattern, all units are updated
     * synchronously until the state stops changing or
     * {@link #MAX_RECALL_LOOPS} extra iterations have run. The result is left
     * in {@code outmatrix} and its listeners are notified.
     */
    public void recalc(){
        // The output starts as a copy of the input and is iterated in place.
        outmatrix.set(0,0,inmatrix);

        // Activation of every unit, laid out like outmatrix.
        pMatrix act = new pMatrix(outmatrix.rows,outmatrix.cols);

        int noloops = 0;
        pMatrix last;
        do {
            last = outmatrix.dup();

            // act = W * s, where s is the state flattened by outmatrix.get(j).
            // Flat index i maps back to (i / cols, i % cols), assuming row-major
            // order. The original used i / rows, which is only right when the
            // matrix is square.
            for(int i=0;i<hopnet.rows;i++){
                double a = 0;
                for(int j=0;j<hopnet.cols;j++){
                    a += hopnet.get(i,j) * outmatrix.get(j);
                }
                act.set(i / act.cols, i % act.cols, a);
            }

            // Sign threshold. A unit with exactly zero input keeps its current
            // state (the standard Hopfield rule) instead of being forced to +1.
            for(int r=0;r<act.rows;r++){
                for(int c=0;c<act.cols;c++){
                    double a = act.get(r,c);
                    if(a > 0)
                        outmatrix.set(r,c,1);
                    else if(a < 0)
                        outmatrix.set(r,c,-1);
                }
            }

            // Presumably subtractEq does last -= outmatrix in place and smult
            // multiplies element-wise, so this is the sum of squared changes.
        }while(last.subtractEq(outmatrix).smult(last).sum() > 0.1 && noloops++ < MAX_RECALL_LOOPS);

        // Notify the GUI.
        outmatrix.update();
    }

    /**
     * Sets the weights with the Hebbian rule:
     * w[i][j] = sum over memories of m[i] * m[j], with w[i][i] = 0.
     * For +1/-1 patterns this is an unnormalized correlation of units i and j.
     * Then notifies the weight view and re-runs {@link #recalc()}.
     */
    public void learnmemories(){
        // Fetch the memories once instead of on every (i, j) pair.
        pMatrix[] mems = new pMatrix[memories.size()];
        for(int k=0;k<mems.length;k++){
            mems[k] = (pMatrix)memories.get(k);
        }

        for(int i=0;i<hopnet.rows;i++){
            for(int j=0;j<hopnet.cols;j++){
                double sum = 0;
                if(i != j){             // no self-connections
                    for(pMatrix m : mems){
                        sum += m.get(i) * m.get(j);
                    }
                }
                hopnet.set(i,j,sum);
            }
        }
        // Notify the GUI.
        hopnet.update();

        // The weights changed, so recompute the output.
        recalc();
    }

}

