Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Copyright [2013-2014] eBay Software Foundation
   *  
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *    http://www.apache.org/licenses/LICENSE-2.0
   *  
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package ml.shifu.guagua.mapreduce.example.nn;
 
 
 
 import java.io.IOException;

 import ml.shifu.guagua.GuaguaRuntimeException;
 import ml.shifu.guagua.io.GuaguaFileSplit;
 import ml.shifu.guagua.mapreduce.GuaguaLineRecordReader;
 import ml.shifu.guagua.mapreduce.GuaguaWritableAdapter;
 import ml.shifu.guagua.util.NumberFormatUtils;
 import ml.shifu.guagua.worker.AbstractWorkerComputable;
 import ml.shifu.guagua.worker.WorkerContext;

 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.Text;
 import org.encog.engine.network.activation.ActivationSigmoid;
 import org.encog.ml.data.MLDataPair;
 import org.encog.ml.data.MLDataSet;
 import org.encog.ml.data.basic.BasicMLData;
 import org.encog.ml.data.basic.BasicMLDataPair;
 import org.encog.ml.data.basic.BasicMLDataSet;
 import org.encog.neural.error.LinearErrorFunction;
 import org.encog.neural.flat.FlatNetwork;
 import org.encog.neural.networks.BasicNetwork;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

 import com.google.common.base.Splitter;
 
NNWorker is used to compute NN model according to splits assigned. The result will be sent to master for accumulation.

Gradients in each worker will be sent to master to update weights of model in worker, which follows Encog's multi-core implementation.

 
 public class NNWorker extends
 
     private static final Logger LOG = LoggerFactory.getLogger(NNWorker.class);

    
Training data set
 
     private MLDataSet trainingData = null;

    
Testing data set
 
     private MLDataSet testingData = null;

    
NN algorithm runner instance.
 
     private Gradient gradient;

    
input record size, inc one by one.
 
     private long count;
 
     private int inputs;
 
     private int hiddens;
 
     private int outputs;

    
Create memory data set object
 
     private void initMemoryDataSet() {
         this. = new BasicMLDataSet();
         this. = new BasicMLDataSet();
     }
 
     @Override
     public void init(WorkerContext<NNParamsNNParamsworkerContext) {
          = NumberFormatUtils.getInt(workerContext.getProps().getProperty(.),
                 .);
          = NumberFormatUtils.getInt(workerContext.getProps().getProperty(.),
                 .);
          = NumberFormatUtils.getInt(workerContext.getProps().getProperty(.),
                 .);
        .info("NNWorker is loading data into memory.");
        initMemoryDataSet();
    }
    @Override
    public NNParams doCompute(WorkerContext<NNParamsNNParamsworkerContext) {
        // For first iteration, we don't do anything, just wait for master to update weights in next iteration. This
        // make sure all workers in the 1st iteration to get the same weights.
        if(workerContext.getCurrentIteration() == 1) {
            return buildEmptyNNParams(workerContext);
        }
        if(workerContext.getLastMasterResult() == null) {
            // This may not happen since master will set initialization weights firstly.
            .warn("Master result of last iteration is null.");
            return null;
        }
        .debug("Set current model with params {}"workerContext.getLastMasterResult());
        // initialize gradients if null
        if( == null) {
            initGradient(this.workerContext.getLastMasterResult().getWeights());
        }
        // using the weights from master to train model in current iteration
        this..setWeights(workerContext.getLastMasterResult().getWeights());
        this..run();
        // get train errors and test errors
        double trainError = this..getError();
        double testError = this..getRecordCount() > 0 ? (this..getNetwork()
                .calculateError(this.)) : 0;
        .info("NNWorker compute iteration {} (train error {} validation error {})",
                new Object[] { workerContext.getCurrentIteration(), trainErrortestError });
        NNParams params = new NNParams();
        params.setTestError(testError);
        params.setTrainError(trainError);
        params.setGradients(this..getGradients());
        // prevent null point;
        params.setWeights(new double[0]);
        params.setTrainSize(this..getRecordCount());
        return params;
    }
    private void initGradient(MLDataSet trainingdouble[] weights) {
        BasicNetwork network = NNUtils.generateNetwork(this.this.this.);
        // use the weights from master
        network.getFlat().setWeights(weights);
        FlatNetwork flat = network.getFlat();
        // copy Propagation from encog
        double[] flatSpot = new double[flat.getActivationFunctions().length];
        for(int i = 0; i < flat.getActivationFunctions().lengthi++) {
            flatSpot[i] = flat.getActivationFunctions()[iinstanceof ActivationSigmoid ? 0.1 : 0.0;
        }
        this. = new Gradient(flattraining.openAdditional(), flatSpotnew LinearErrorFunction());
    }
    private NNParams buildEmptyNNParams(WorkerContext<NNParamsNNParamsworkerContext) {
        NNParams params = new NNParams();
        params.setWeights(new double[0]);
        params.setGradients(new double[0]);
        params.setTestError(0.0d);
        params.setTrainError(0.0d);
        return params;
    }
    @Override
    protected void postLoad(WorkerContext<NNParamsNNParamsworkerContext) {
        .info("- # Records of the whole data set: {}."this.);
        .info("- # Records of the training data set: {}."this..getRecordCount());
        .info("- # Records of the testing data set: {}."this..getRecordCount());
    }
    @Override
    public void load(GuaguaWritableAdapter<LongWritable> currentKeyGuaguaWritableAdapter<Text> currentValue,
            WorkerContext<NNParamsNNParamsworkerContext) {
        ++this.;
        if((this.) % 100000 == 0) {
            .info("Read {} records."this.);
        }
        // use guava to iterate only once
        double[] ideal = new double[1];
        int inputNodes = NumberFormatUtils.getInt(
                workerContext.getProps().getProperty(.),
                .);
        double[] inputs = new double[inputNodes];
        int i = 0;
        for(String input: Splitter.on(.).split(
                currentValue.getWritable().toString())) {
            if(i == 0) {
                ideal[i++] = NumberFormatUtils.getDouble(input, 0.0d);
            } else {
                int inputsIndex = (i++) - 1;
                if(inputsIndex >= inputNodes) {
                    break;
                }
                inputs[inputsIndex] = NumberFormatUtils.getDouble(input, 0.0d);
            }
        }
        if(i < (inputNodes + 1)) {
            throw new GuaguaRuntimeException(String.format(
                    "Not enough data columns, input nodes setting:%s, data column:%s"inputNodesi));
        }
        int scale = NumberFormatUtils.getInt(workerContext.getProps().getProperty(.), 1);
        for(int j = 0; j < scalej++) {
            double[] tmpInputs = j == 0 ? inputs : new double[inputs.length];
            double[] tmpIdeal = j == 0 ? ideal : new double[ideal.length];
            System.arraycopy(inputs, 0, tmpInputs, 0, inputs.length);
            MLDataPair pair = new BasicMLDataPair(new BasicMLData(tmpInputs), new BasicMLData(tmpIdeal));
            double r = Math.random();
            if(r >= 0.5d) {
                this..add(pair);
            } else {
                this..add(pair);
            }
        }
    }
    /*
     * (non-Javadoc)
     * 
     * @see ml.shifu.guagua.worker.AbstractWorkerComputable#initRecordReader(ml.shifu.guagua.io.GuaguaFileSplit)
     */
    @Override
    public void initRecordReader(GuaguaFileSplit fileSplitthrows IOException {
        this.setRecordReader(new GuaguaLineRecordReader());
        this.getRecordReader().initialize(fileSplit);
    }
    public MLDataSet getTrainingData() {
        return ;
    }
    public void setTrainingData(MLDataSet trainingData) {
        this. = trainingData;
    }
    public MLDataSet getTestingData() {
        return ;
    }
    public void setTestingData(MLDataSet testingData) {
        this. = testingData;
    }
New to GrepCode? Check out our FAQ X