/*

  Backpropagating Neural Network java application
  by Scott Teresi
  www.teresi.us

  Developed at Univ. of Toledo and Univ. of Illinois
  May-October 1997

*/


import java.awt.*;
import java.io.*;


public class Net extends Frame {

  public static NetInterface netInterface;
  public static StatusWindow statusWindow;
  public static OutputsWindow outputsWindow;

  public static int MAX_LAYERS = 5;    // max no. of layers allowed
  public static int MAX_NODES = 100;   // max no. of nodes allowed per layer
  public static int MAX_SETS = 500;    // max no. of training sets allowed
  public static int window_x = 600;    // default size of window on screen
  public static int window_y = 440;
  public static float learn_rate = (float) 0.05;
  public static int iterations = 10000;
  public static int cur_iter = 0;      // iteration in progress (0 when idle)
  public static StringBuffer inputFileNm = new StringBuffer
                                           ("data/glass_norm.train");
  public static boolean boolean_func = false;  // only read by sigmoid()

  public static int num_layers = 3; // no. of layers of neurons in the network
  public static int max_nodes = 24; // nominal max no. of nodes in any layer
  public static int input_nodes;    // no. of nodes in the input layer
  public static int output_nodes;   // no. of nodes in the output layer
  public static int training_sets;  // sets of inputs/outputs for training
  public static int num_nodes[] = new int[MAX_LAYERS+1];
                               // no. of nodes (neurons) in each layer, 1-based
  public static LearnData[] learndata;
                               // training data, 1-based (slot 0 unused)
  public static Neuron[][] neuron = new Neuron[MAX_LAYERS+1][MAX_NODES+1];
                               // the net of neurons, [layer][node]; node 0 of
                               // each layer is the constant threshold node
  private static Graphics screen;  // screen on which to draw the net

  // Constants of the squashing function tanh() below
  // (suggested constants from Haykin, p. 160)
  private static final double TANH_A = 1.716;
  private static final double TANH_B = .6666666;



  // Application entry point: create the net, set defaults, open the UI.

  public static void main (String args[]) {
    Net net = new Net();
    net.init();
    net.start();
  }



  // Initial values: default topology of 2 inputs, 5 hidden nodes, 1 output.

  public void init () {
    // roughly the no. of lines any training data file may contain
    learndata = new LearnData[MAX_SETS + 1];

    num_nodes[1] = 2;
    num_nodes[2] = 5;
    num_nodes[3] = 1;
    for (int i = 4; i <= MAX_LAYERS; i ++)   // layers beyond num_layers
      num_nodes[i] = 0;

    input_nodes = num_nodes[1];
    output_nodes = num_nodes[num_layers];
  }



  // Open the main window and build/display the initial network.

  public void start () {
      // create status window
    // statusWindow = new StatusWindow(this);   // StatusWindow removed
    // statusWindow.show();
    // statusWindow.resize(120, 170);
      // create main screen
    netInterface = new NetInterface(this);
    netInterface.show();
    netInterface.resize (window_x, window_y);
    screen = netInterface.getGraphics();
      // initialize and display the net
    initNetwork();
  }



  // Read the training data file named by inputFileNm.
  //
  // File layout, one item per line (blank lines are ignored):
  //   line 1:  no. of leading comma-separated fields to skip on each data line
  //   line 2:  no. of inputs
  //   line 3:  no. of outputs
  //   line 4+: one training set per line -- skipped fields, then the inputs,
  //            then the outputs, all separated by commas
  // At most MAX_SETS training sets are kept. A file that cannot be read
  // leaves the current data untouched. If the no. of inputs or outputs
  // changes, the network is rebuilt.

  public void readInputFile () {
    int old_input_nodes = input_nodes;
    int old_output_nodes = output_nodes;
    int lineNo = 0;          // physical line number, for error messages
    int headerItems = 0;     // how many of the 3 header values were read
    int ignoreFields = 0;
    String line;
    BufferedReader in = null;

    try {
      in = new BufferedReader (new FileReader (inputFileNm.toString()));
      training_sets = 0;     // the file opened, so replace the old data
      while ((line = in.readLine()) != null) {
        lineNo ++;
        line = line.trim();
        if (line.length() == 0)
          continue;
        if (headerItems < 3) {
          int value = Integer.parseInt(line);
          if (headerItems == 0)
            ignoreFields = value;   // no. of fields to skip at start of line
          else if (headerItems == 1)
            input_nodes = value;    // number of inputs
          else
            output_nodes = value;   // number of outputs
          headerItems ++;
        } else if (training_sets < MAX_SETS) {
          // only count the set once it has parsed completely
          learndata[training_sets + 1] = parseDataLine(line, ignoreFields);
          training_sets ++;
        } else {
          System.out.println ("Warning: only the first " + MAX_SETS
                              + " training sets were read.");
          break;
        }
      }
    } catch (IOException endFile) {
      System.out.println ("File error.");
    } catch (NumberFormatException badData) {
      System.out.println ("File format error on line " + lineNo + ": "
                          + badData.getMessage());
    } finally {
      if (in != null) {
        try {
          in.close();
        } catch (IOException ignored) {
          // nothing useful to do if close fails; the data is already read
        }
      }
    }

/*
    // Hard-coded alternative to the file above: logical XOR (-1 = false)

    training_sets = 4;

    for (int i = 1; i <= training_sets; i ++) {
      learndata[i] = new LearnData(input_nodes, output_nodes);
      learndata[i].train_input[0] = 1;  // threshold value
    }

    learndata[1].train_input[1] = -1;
    learndata[1].train_input[2] = -1;
    learndata[1].train_output[1] = -1;

    learndata[2].train_input[1] = -1;
    learndata[2].train_input[2] = 1;
    learndata[2].train_output[1] = 1;

    learndata[3].train_input[1] = 1;
    learndata[3].train_input[2] = -1;
    learndata[3].train_output[1] = 1;

    learndata[4].train_input[1] = 1;
    learndata[4].train_input[2] = 1;
    learndata[4].train_output[1] = -1;
*/

    num_nodes[1] = input_nodes;
    num_nodes[num_layers] = output_nodes;

    // clear network if topology is different
    if ((old_input_nodes != input_nodes) || (old_output_nodes != output_nodes))
      initNetwork();
  }


  // Parse one data line into a training set, skipping the first
  // ignoreFields fields. Throws NumberFormatException if a field is
  // missing or is not a number.

  private LearnData parseDataLine (String line, int ignoreFields) {
    LearnData data = new LearnData(input_nodes, output_nodes);
    int index = 0;

    for (int i = 1; i <= ignoreFields; i ++)
      index += getExtentOfNum(index, line, ",") + 1;
    index = parseFields(line, index, data.train_input, input_nodes);
    parseFields(line, index, data.train_output, output_nodes);
    data.train_input[0] = 1;  // threshold input, feeds each bias weight
    return data;
  }


  // Parse count comma-separated numbers from line, starting at index, into
  // dest[1..count]. Returns the index just past the last field parsed.

  private int parseFields (String line, int index, float dest[], int count) {
    for (int i = 1; i <= count; i ++) {
      int extent = getExtentOfNum(index, line, ",");
      if (extent < 0)
        throw new NumberFormatException ("too few fields");
      dest[i] = (Float.valueOf
          (line.substring(index, index + extent))).floatValue();
      index += extent + 1;
    }
    return index;
  }



  // Find the length of a number up to the next separator (e.g. comma).
  //
  // "start" is the starting position in the string
  // "targetStr" is the string to search within
  // "separator" signals the end of the current number; expected to be a
  //   single character such as ","
  // Returns the no. of characters from start up to (not including) the next
  // separator, or up to the end of the string if none follows; returns -1
  // if start lies outside the string.

  int getExtentOfNum(int start, String targetStr, String separator) {
    if (start < 0 || start >= targetStr.length())
      return -1;
    int end = targetStr.indexOf(separator, start);
    return (end < 0) ? targetStr.length() - start : end - start;
  }


  // Quit program: close and release the main window.

  public void quitProgram () {
    netInterface.hide();
    netInterface.dispose();
    remove(netInterface);
    // statusWindow.hide();   // StatusWindow removed
    // statusWindow.dispose();
    // remove(statusWindow);
  }


  // Initialize the network: new neurons, screen layout, cleared state.

  void initNetwork () {
    createNet();
    calcCoords(window_x, window_y);
    clearNet();
  }


  // Create neuron objects for every node (including threshold node 0)
  // of every layer.

  void createNet () {
    for (int layer = 1; layer <= num_layers; layer ++)
      for (int node = 0; node <= num_nodes[layer]; node ++)
        neuron[layer][node] = new Neuron(this, layer, node, screen);
  }


  // Clear memory: reset all weights, set threshold nodes (node 0) to 1 and
  // every other node to 0.

  void clearNet () {
    for (int layer = 1; layer <= num_layers; layer ++)
      for (int node = 0; node <= num_nodes[layer]; node ++) {
        neuron[layer][node].clearWeights();
        if (node == 0)
          neuron[layer][node].setValue(1);  // for the threshold weight
        else
          neuron[layer][node].setValue(0);
      }
  }


  // Initialize layout: compute each neuron's on-screen position and size
  // so the whole net fits a wdth x hght window. Threshold nodes (node 0)
  // are not given coordinates here.

  void calcCoords (int wdth, int hght) {

    int layer, node;
    int offset_x[] = new int[num_layers+1];  // row's left margin
    int div[] = new int[num_layers+1];       // spacing of each node
    int width[] = new int[num_layers+1];     // width of each node's circle
    int cap_width = 45;                      // largest width of any node
    int v_space = 10;                        // space between node layers
    int acc_space = 0;                  // "height" of all layers, w/ no spaces

    int new_wdth = wdth / 10 * 9;            // size of window to work within
    int new_hght = hght / 10 * 9;
    int margin_x = (wdth - new_wdth) / 2;    // space around edge of net
    int margin_y = (hght - new_hght);        // all leftover height goes on top
    wdth = new_wdth;
    hght = new_hght;

    // ensure neuron doesn't look too big
    if ((cap_width > (.2 * wdth)) && (wdth > 25))
      cap_width = (int) (.2 * wdth);
    // divide up space across screen
    do {
      cap_width -= 5;
      for (layer = 1; layer <= num_layers; layer ++) {
        // Math.max guards against a layer configured with 0 nodes
        div[layer] = (wdth / Math.max(1, num_nodes[layer]));
        width[layer] = div[layer] * 5/6;
        if (width[layer] > cap_width) {
          width[layer] = cap_width;
          div[layer] = width[layer] * 6/5;
        }
        // center the row
        offset_x[layer] = (wdth - (num_nodes[layer] * div[layer])) / 2;
        acc_space += width[layer];
      }
      // equalize space between each row
      v_space = (hght - acc_space) / num_layers;
      acc_space = 0;
    // squash neurons until fit vertically
    } while ((v_space < 5) && !(cap_width <= 5));

    // plot each neuron
    for (layer = 1; layer <= num_layers; layer ++) {
      for (node = 1; node <= num_nodes[layer]; node ++)
        neuron[layer][node].setCoord
            (offset_x[layer] + (node-1) * div[layer] + margin_x,
            layer * v_space + acc_space - v_space/2 + margin_y, width[layer]);
      acc_space += width[layer];
    }
  }


  // Train the net: run `iterations` training steps, each on a training set
  // chosen at random. cur_iter holds the current iteration while training
  // and is reset to 0 afterwards.

  void train () {
    if (training_sets < 1) {
      System.out.println ("Error: no training data loaded");
      return;
    }

    for (cur_iter = 1; cur_iter <= iterations; cur_iter ++) {
      int train_set = 1 + (int) (Math.random() * training_sets);
      if (train_set > training_sets) {   // defensive: Math.random() < 1
        System.out.println
            ("Error: training set chosen out of bounds (" + train_set + ")");
        train_set = training_sets;
      }
      trainNet(learndata[train_set]);
    }
    cur_iter = 0;
  }


  // One training step: forward pass on the set's inputs, then adjust the
  // weights toward the set's outputs.

  void trainNet (LearnData data) {
    receiveInputs (data.train_input);
    propagate();
    backprop (data.train_output);
    // (statusWindow.repaint() calls between steps removed with StatusWindow)
  }


  // Run an input through the net (forward pass only).

  void runNet (float train_input[]) {
    receiveInputs (train_input);
    propagate();
    // (statusWindow.repaint() calls removed with StatusWindow)
  }


  // Place training set inputs on the input layer, including the threshold
  // value at index 0.

  void receiveInputs (float train_input[]) {
    for (int node = 0; node <= input_nodes; node ++)
      neuron[1][node].setValue(train_input[node]);
  }


  // Propagate inputs through net.
  //
  // For each pair of adjacent layers i_row -> j_row, every node in j_row
  // gets the weighted sum of all i_row nodes (threshold node 0 included),
  // passed through the squashing function. Neurons therefore hold their
  // squashed output after this call.

  void propagate () {
    for (int i_row = 1, j_row = 2; j_row <= num_layers; i_row ++, j_row ++) {
      for (int j_node = 1; j_node <= num_nodes[j_row]; j_node ++) {
        Neuron target = neuron[j_row][j_node];
        float sum = 0;
        for (int i_node = 0; i_node <= num_nodes[i_row]; i_node ++)
          sum += neuron[i_row][i_node].getValue() * target.getWeight(i_node);
        // apply squashing function to be ready for more propagating
        target.setValue(tanh(sum));
      }
    }
  }


  // Backpropagate correct outputs.
  //
  // Must follow propagate() for the same inputs: it relies on every neuron
  // holding its squashed output y. Standard gradient-descent backprop:
  //   output layer:  delta_j = (target_j - y_j) * f'(y_j)
  //   hidden layers: delta_j = f'(y_j) * sum_k delta_k * w_kj
  //   all weights:   w_ji   += learn_rate * delta_j * y_i
  // where i runs from 0 so the threshold (bias) weights are trained too.
  // All deltas are computed from the old weights before any weight changes.

  void backprop (float train_output[]) {
    // size the table from the real layer widths; Java zero-fills it
    int widest = 0;
    for (int layer = 1; layer <= num_layers; layer ++)
      widest = Math.max(widest, num_nodes[layer]);
    float delta[][] = new float[num_layers + 1][widest + 1];

    // output layer: error times slope of the squashing function
    for (int j = 1; j <= num_nodes[num_layers]; j ++) {
      float out = neuron[num_layers][j].getValue();
      delta[num_layers][j] = (train_output[j] - out) * derivTanhOutput(out);
    }

    // hidden layers, stepping backward: gather deltas from the layer above
    for (int row = num_layers - 1; row >= 2; row --) {
      for (int j = 1; j <= num_nodes[row]; j ++) {
        float sum = 0;
        for (int k = 1; k <= num_nodes[row + 1]; k ++)
          sum += delta[row + 1][k] * neuron[row + 1][k].getWeight(j);
        delta[row][j] = sum * derivTanhOutput(neuron[row][j].getValue());
      }
    }

    // adjust weights, including weight 0 from the threshold node
    for (int row = num_layers; row >= 2; row --)
      for (int j = 1; j <= num_nodes[row]; j ++)
        for (int i = 0; i <= num_nodes[row - 1]; i ++)
          neuron[row][j].setWeight
              (i, neuron[row][j].getWeight(i)
              + learn_rate * delta[row][j] * neuron[row - 1][i].getValue());
  }



  // Squashing function used by the net.
  // Despite its name this computes 2a / (1 + e^(-b*x)) - a, which equals
  // a * tanh(b*x / 2), with a = TANH_A and b = TANH_B. Output lies in
  // (-a, a); since a > 1, targets of +/-1 are reachable.

  float tanh (float x) {
    return (float) ((2 * TANH_A) / (1 + Math.exp(-TANH_B * x)) - TANH_A);
  }


  // Derivative of tanh() above at x:
  //   f'(x) = (b / 2a) * (a - f(x)) * (a + f(x))

  float derivTanh (float x) {
    return derivTanhOutput(tanh(x));
  }


  // Derivative of tanh() written in terms of its output y = tanh(x).
  // backprop() uses this form because neurons store only squashed outputs.

  private float derivTanhOutput (float y) {
    return (float) (TANH_B / (2 * TANH_A) * (TANH_A - y) * (TANH_A + y));
  }


  // Sigmoid squashing function
  // for a smooth transition between 0 and 1. Not called within this class.
  // NOTE(review): the boolean_func branch (x - 1) was marked "is this
  // correct?" by the author -- its intent is unconfirmed.

  float sigmoid (float x) {
    if (boolean_func)
      return x - 1;
    return (float) (1 / (1 + Math.exp (-x)));
  }


  // Derivative of the sigmoid function: s(x) * (1 - s(x))  (Haykin p. 149)

  float derivSigmoid (float x) {
    float s = sigmoid(x);
    return s * (1 - s);
  }


}

