Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Copyright [2013-2014] eBay Software Foundation
   *  
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *    http://www.apache.org/licenses/LICENSE-2.0
   *  
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package ml.shifu.guagua.mapreduce.example.nn.meta;
 
 
/**
 * {@code NNParams} holds neural-network model info; it can also be stored into ZooKeeper.
 *
 * <p>
 * {@code weights} holds the model weights and is used to transfer them from the master to the workers.
 *
 * <p>
 * {@code gradients} is used by the master to accumulate the gradients from all workers; the accumulated gradients
 * are then used to update the neural network weights.
 */

 
 public class NNParams extends HaltBytable {

    
Weights used for NN model
 
     private double[] weights;

    
Gradients for NN model
 
     private double[] gradients;

    
Current test error which can be sent to master
 
     private double testError = 0;

    
Current train error which can be sent to master
 
     private double trainError = 0;

    
Training size of each worker and master
 
     private long trainSize = 0;
 
     public double[] getWeights() {
         return ;
     }
 
     public void setWeights(double[] weights) {
         this. = weights;
     }
 
     public double getTestError() {
         return ;
     }
 
     public void setTestError(double testError) {
         this. = testError;
     }
 
     public double getTrainError() {
         return ;
     }
 
     public void setTrainError(double trainError) {
         this. = trainError;
     }
 
     public void accumulateGradients(double[] gradients) {
         if(this. == null) {
             this. = new double[gradients.length];
             Arrays.fill(this., 0.0);
         }
 
         if(this. == null) {
             this. = new double[gradients.length];
             NNUtils.randomize(gradients.lengththis.);
         }
 
         for(int i = 0; i < gradients.lengthi++) {
             this.[i] += gradients[i];
        }
    }

    

    /**
     * Returns the gradients.
     *
     * @return the gradients
     */
    public double[] getGradients() {
        return ;
    }

    

    /**
     * Sets the gradients.
     *
     * @param gradients the gradients to set
     */
    public void setGradients(double[] gradients) {
        this. = gradients;
    }
    public long getTrainSize() {
        return ;
    }
    public void setTrainSize(long trainSize) {
        this. = trainSize;
    }
    public void accumulateTrainSize(long size) {
        this. = this.getTrainSize() + size;
    }
    public void reset() {
        this.setTrainSize(0);
        if(this. != null) {
            Arrays.fill(this., 0.0);
        }
    }
    @Override
    public void doWrite(DataOutput outthrows IOException {
        out.writeDouble(getTrainError());
        out.writeDouble(getTestError());
        out.writeLong(getTrainSize());
        out.writeInt(getWeights().length);
        for(double weightgetWeights()) {
            out.writeDouble(weight);
        }
        out.writeInt(getGradients().length);
        for(double gradientgetGradients()) {
            out.writeDouble(gradient);
        }
    }
    @Override
    public void doReadFields(DataInput inthrows IOException {
        this. = in.readDouble();
        this. = in.readDouble();
        this. = in.readLong();
        int len = in.readInt();
        double[] weights = new double[len];
        for(int i = 0; i < leni++) {
            weights[i] = in.readDouble();
        }
        this. = weights;
        len = in.readInt();
        double[] gradients = new double[len];
        for(int i = 0; i < leni++) {
            gradients[i] = in.readDouble();
        }
        this. = gradients;
    }
    @Override
    public String toString() {
        return String.format("NNParams [testError=%s, trainError=%s, trainSize=%s, weights=%s, gradients%s]",
                this.this.this., Arrays.toString(this.),
                Arrays.toString(this.));
    }
New to GrepCode? Check out our FAQ X