Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
 /*
  * Copyright [2013-2014] eBay Software Foundation
  *  
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *    http://www.apache.org/licenses/LICENSE-2.0
  *  
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.guagua.mapreduce.example.lnr;
LinearRegressionMaster defines the logic that updates the global linear regression model.

In the first iteration, the master builds a random model and sends it to all workers so they can start computing. This ensures that all workers start from the same model.

In each subsequent iteration, the master:

  1. Accumulates (sums) the gradients reported by all workers.
  2. Updates the global model using the accumulated gradients.
  3. Sends the new global model to the workers by returning its parameters.
    // Logger used for the per-iteration weight and error messages.
    private static final Logger LOG = LoggerFactory.getLogger(LinearRegressionMaster.class);
    // Random source, presumably used to draw the initial weights (the reference was stripped in
    // extraction). Created with the no-arg constructor, i.e. unseeded, so runs are not reproducible.
    private static final Random RANDOM = new Random();
    // Number of input features; presumably assigned in init(...) — TODO confirm.
    private int inputNum;
    // Global model weights. NOTE(review): presumably allocated with length inputNum + 1 on the first
    // iteration (the extra slot is likely a bias term) — confirm against the original source.
    private double[] weights;
    // Step size applied to the accumulated gradients; presumably assigned in init(...) — TODO confirm.
    private double learnRate;
    /**
     * Initializes the master's numeric settings before the first iteration.
     *
     * <p>NOTE(review): the field names, argument constants and generic-type commas in this method were
     * lost when the source was extracted. The first assignment presumably sets {@code inputNum} via
     * {@code NumberFormatUtils.getInt}. The second presumably sets {@code learnRate} via
     * {@code NumberFormatUtils.getDouble}; they are the only int and double fields. Each call appears to
     * take a value and a default. Confirm against the original file.
     *
     * @param context master context for the current job
     */
    private void init(MasterContext<LinearRegressionParamsLinearRegressionParamscontext) {
        this. = NumberFormatUtils.getInt(.,
                .);
        this. = NumberFormatUtils.getDouble(.,
                .);
    }
    /**
     * Runs one master iteration of linear-regression training.
     *
     * <p>If {@code context.isFirstIteration()} is true, it calls {@link #init} and fills a newly
     * allocated array with {@code nextDouble()} values. This is presumably {@code weights} filled from
     * {@code RANDOM}.
     *
     * <p>On later iterations, it does the following:
     * <ul>
     *   <li>Sums {@code getParameters()} element-wise over all of {@code context.getWorkerResults()}
     *       into a gradient array. The gradients are summed, not averaged.</li>
     *   <li>Subtracts a scaled copy of that array from the model.</li>
     *   <li>Logs the weights and the mean of the workers' {@code getError()} values.</li>
     * </ul>
     *
     * <p>NOTE(review): the method signature that should follow {@code @Override} is missing from this
     * extracted source. So are most identifiers (field names, the logger, commas). Names used in the
     * comments below are reconstructions; confirm them against the original file.
     *
     * @return presumably a {@code LinearRegressionParams} carrying the updated weights for the workers
     *     (constructor arguments were lost in extraction)
     */
    @Override
        if(context.isFirstIteration()) {
            // First iteration: load settings, then build the random starting model shared by all workers.
            init(context);
             = new double[this. + 1];
            for(int i = 0; i < .i++) {
                [i] = .nextDouble();
            }
        } else {
            double[] gradients = new double[this. + 1];
            double sumError = 0.0d;
            int size = 0;
            for(LinearRegressionParams paramcontext.getWorkerResults()) {
                if(param != null) {
                    for(int i = 0; i < gradients.lengthi++) {
                        gradients[i] += param.getParameters()[i];
                    }
                }
                // NOTE(review): BUG — param is dereferenced here outside the null check above. A null
                // worker result therefore throws NullPointerException, and null results are also
                // counted in size. These two statements belong inside the if-block.
                sumError += param.getError();
                size++;
            }
            // Gradient-descent step, presumably weights[i] -= learnRate * gradients[i]
            // (identifiers lost in extraction) — TODO confirm.
            for(int i = 0; i < .i++) {
                [i] -=  * gradients[i];
            }
            // NOTE(review): a "DEBUG:" message logged at info level; consider a debug-level call.
            .info("DEBUG: Weights: {}", Arrays.toString(this.));
            // NOTE(review): if no worker results arrive, size is 0 and sumError / size logs NaN.
            .info("Iteration {} with error {}"context.getCurrentIteration(), sumError / size);
        }
        return new LinearRegressionParams();
    }
New to GrepCode? Check out our FAQ X