Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
Copyright [2013-2014] eBay Software Foundation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
 package ml.shifu.guagua.mapreduce.example.nn;
NNMaster is used to accumulate all workers' NN parameters.

We accumulate all gradients from workers to calculate model weights. And set weights to workers. Then workers use weights to set their models and train for another iteration.

This logic follows Encog multi-core implementation.

Make sure workers and master use the same initialization weights.

 public class NNMaster implements MasterComputable<NNParamsNNParams> {
     private static final Logger LOG = LoggerFactory.getLogger(NNMaster.class);

Global master NN parameters instance which is used to update model weights by using accumulated gradients.
     private NNParams globalNNParams = new NNParams();

Whether some configurations are initialized
     private AtomicBoolean isInitialized = new AtomicBoolean(false);

To calculate weights according to last weights and accumulated gradients
     private Weight weightCalculator = null;
     private double learningRate;
     public NNParams compute(MasterContext<NNParamsNNParamscontext) {
         // For first step, we not only initialize whole context but also return weights to master to make sure all
         // workers and master are using the same weights.
         if(this..compareAndSet(falsetrue)) {
             // first iteration is used to set initial weights
             NNParams params = initWeights(context);
             // should be set here to make sure master and workers use the same weights
             return params;
         if(context.getWorkerResults() == null) {
             throw new IllegalArgumentException("workers' results are null.");
         double totalTestError = 0;
         double totalTrainError = 0;
         int size = 0;
         // before accumulate, reset gradients and train size
         for(NNParams nncontext.getWorkerResults()) {
             totalTestError += nn.getTestError();
             totalTrainError += nn.getTrainError();
         // worker result size is 0. throw exception because shouldn't happen
         if(size == 0) {
             throw new IllegalArgumentException("workers' results are empty.");
        // initialize weightCalCulater.
        if(this. == null) {
            // get the learning rate
            this. = new Weight(this..getGradients().length,
                    this..getTrainSize(), this..);
        // use last weights and current gradients to calculate
        double[] weights = this..calculateWeights(this..getWeights(),
        double currentTestError = totalTestError / size;
        double currentTrainError = totalTrainError / size;
        .info("NNMaster compute iteration {} ( avg train error {}, avg validation error {} )"new Object[] {
                context.getCurrentIteration(), currentTrainErrorcurrentTestError });
        NNParams params = new NNParams();
        // prevent null point
        params.setGradients(new double[0]);
        .debug("master result {} in iteration {}"paramscontext.getCurrentIteration());
        return params;
    private NNParams initWeights(MasterContext<NNParamsNNParamscontext) {
        int inputs = NumberFormatUtils.getInt(context.getProps().getProperty(.),
        int hiddens = NumberFormatUtils.getInt(context.getProps().getProperty(.),
        int outputs = NumberFormatUtils.getInt(context.getProps().getProperty(.),
        this. = NumberFormatUtils.getDouble(context.getProps().getProperty(
        BasicNetwork network = NNUtils.generateNetwork(inputshiddensoutputs);
        NNParams params = new NNParams();
        // prevent null point
        params.setGradients(new double[0]);
        return params;
New to GrepCode? Check out our FAQ X