Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
/*
 * Copyright [2013-2014] eBay Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
 package ml.shifu.guagua.mapreduce.example.nn;
 
 
/**
 * {@code Gradient} is copied from the Encog framework. The reason is that the original Encog gradient class
 * does not expose its gradients to callers, while here the gradients need to be accumulated in
 * {@code NNMaster} to update the neural-network weights.
 */
 
 public class Gradient {

    
The network to train.
 
     private FlatNetwork network;

    
The error calculation method.
 
     private final ErrorCalculation errorCalculation = new ErrorCalculation();

    
The actual values from the neural network.
 
     private final double[] actual;

    
The deltas for each layer.
 
     private final double[] layerDelta;

    
The neuron counts, per layer.
 
     private final int[] layerCounts;

    
The feed counts, per layer.
 
     private final int[] layerFeedCounts;

    
The layer indexes.
 
     private final int[] layerIndex;

    
The index to each layer's weights and thresholds.
 
     private final int[] weightIndex;

    
The output from each layer.
 
     private final double[] layerOutput;

    
The sums.
 
     private final double[] layerSums;

    
The gradients.
 
     private double[] gradients;

    
The weights and thresholds.
 
     private double[] weights;

    
The pair to use for training.
 
     private final MLDataPair pair;

    
The training data.
    private final MLDataSet training;

    
error
    private double error;

    
Derivative add constant. Used to combat flat spot.
    private double[] flatSpot;

    
The error function to use.
    private final ErrorFunction errorFunction;

    
Construct a gradient worker.

Parameters:
theNetwork The network to train.
theOwner The owner that is doing the training.
theTraining The training data.
theLow The low index to use in the training data.
theHigh The high index to use in the training data.
    public Gradient(final FlatNetwork theNetworkfinal MLDataSet theTrainingfinal double[] flatSpotErrorFunction ef) {
        this. = theNetwork;
        this. = theTraining;
        this. = flatSpot;
        this. = ef;
        this. = new double[getNetwork().getLayerOutput().length];
        this. = new double[getNetwork().getWeights().length];
        this. = new double[getNetwork().getOutputCount()];
        this. = getNetwork().getWeights();
        this. = getNetwork().getLayerIndex();
        this. = getNetwork().getLayerCounts();
        this. = getNetwork().getWeightIndex();
        this. = getNetwork().getLayerOutput();
        this. = getNetwork().getLayerSums();
        this. = getNetwork().getLayerFeedCounts();
        this. = BasicMLDataPair.createPair(getNetwork().getInputCount(), getNetwork().getOutputCount());
    }

    
Process one training set element.

Parameters:
input The network input.
ideal The ideal values.
s The significance.
    private void process(final double[] inputfinal double[] idealdouble s) {
        this.getNetwork().compute(inputthis.);
        this..updateError(this.ideals);
        this..calculateError(idealthis.getLayerDelta());
        for(int i = 0; i < this..lengthi++) {
            this.getLayerDelta()[i] = ((this.getNetwork().getActivationFunctions()[0].derivativeFunction(
                    this.[i], this.[i]) + this.[0])) * (this.getLayerDelta()[i] * s);
        }
        for(int i = this.getNetwork().getBeginTraining(); i < this.getNetwork().getEndTraining(); i++) {
            processLevel(i);
        }
    }

    
Process one level.

Parameters:
currentLevel The level.
    private void processLevel(final int currentLevel) {
        final int fromLayerIndex = this.[currentLevel + 1];
        final int toLayerIndex = this.[currentLevel];
        final int fromLayerSize = this.[currentLevel + 1];
        final int toLayerSize = this.[currentLevel];
        final int index = this.[currentLevel];
        final ActivationFunction activation = this.getNetwork().getActivationFunctions()[currentLevel + 1];
        final double currentFlatSpot = this.[currentLevel + 1];
        // handle weights
        int yi = fromLayerIndex;
        for(int y = 0; y < fromLayerSizey++) {
            final double output = this.[yi];
            double sum = 0;
            int xi = toLayerIndex;
            int wi = index + y;
            for(int x = 0; x < toLayerSizex++) {
                this.[wi] += output * this.getLayerDelta()[xi];
                sum += this.[wi] * this.getLayerDelta()[xi];
                wi += fromLayerSize;
                xi++;
            }
            this.getLayerDelta()[yi] = sum
                    * (activation.derivativeFunction(this.[yi], this.[yi]) + currentFlatSpot);
            yi++;
        }
    }

    
Perform the gradient calculation
    public final void run() {
        try {
            // reset errors and gradients firstly
            this..reset();
            Arrays.fill(this., 0.0);
            for(int i = 0; i < this..getRecordCount(); i++) {
                this..getRecord(ithis.);
                process(this..getInputArray(), this..getIdealArray(), .getSignificance());
            }
            this. = this..calculate();
        } catch (final Throwable ex) {
            throw new RuntimeException(ex);
        }
    }
        return ;
    }

    

Returns:
the gradients
    public double[] getGradients() {
        return this.;
    }

    

Returns:
the error
    public double getError() {
        return ;
    }

    

Returns:
the weights
    public double[] getWeights() {
        return ;
    }

    

Parameters:
weights the weights to set
    public void setWeights(double[] weights) {
        this. = weights;
        this.getNetwork().setWeights(weights);
    }
    public void setParams(BasicNetwork network) {
        this. = network.getFlat();
        this. = network.getFlat().getWeights();
    }
    public FlatNetwork getNetwork() {
        return ;
    }
    public double[] getLayerDelta() {
        return ;
    }
New to GrepCode? Check out our FAQ X