Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Copyright [2013-2014] eBay Software Foundation
   *  
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *    http://www.apache.org/licenses/LICENSE-2.0
   *  
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package ml.shifu.guagua.mapreduce.example.lnr;
 
 import java.util.List;
 
 
 import  org.apache.hadoop.io.LongWritable;
 import  org.apache.hadoop.io.Text;
 
 /**
  * LinearRegressionWorker defines the logic to accumulate local linear regression gradients.
  *
  * <p>At the first iteration, the worker waits for the master so that every worker starts from the
  * same initial model.
  *
  * <p>At every other iteration, the worker:
  * <ol>
  * <li>updates its local model with the global model from the last step;</li>
  * <li>accumulates gradients over its local input data;</li>
  * <li>sends the new local gradients to the master by returning them as parameters.</li>
  * </ol>
  *
  * <p>WARNING: input data should be normalized beforehand, or you will get a very bad model.
  */

 
 public class LinearRegressionWorker
         extends
 
     private static final Logger LOG = LoggerFactory.getLogger(LinearRegressionWorker.class);

    
Input column number
 
     private int inputNum;

    
Output column number
 
     private int outputNum;

    
In-memory data which located in memory at the first iteration.
 
     private List<DatadataList;

    
Local linear regression model.
 
     private double[] weights;

    
A splitter to split data with specified delimiter.
 
     private Splitter splitter = Splitter.on(",");
 
     @Override
     public void initRecordReader(GuaguaFileSplit fileSplitthrows IOException {
         this.setRecordReader(new GuaguaLineRecordReader(fileSplit));
     }
 
     /**
      * Initializes the column counts and the in-memory record list before any data is loaded.
      *
      * NOTE(review): this block was damaged when the file was extracted. The field names after
      * {@code this.}, the generic-argument separator and both arguments to
      * {@code NumberFormatUtils.getInt} are missing. Restore them from the upstream source.
      */
     @Override
     public void init(WorkerContext<LinearRegressionParamsLinearRegressionParamscontext) {
         // Presumably assigns inputNum from NumberFormatUtils.getInt(<key?>, <default?>) - TODO confirm.
         this. = NumberFormatUtils.getInt(.,
                 .);
         // Presumably outputNum: a single regression target - TODO confirm.
         this. = 1;
         // Presumably dataList: the list that load() appends one Data record to per line.
         this. = new LinkedList<Data>();
     }
 
     /**
      * Computes this worker's result for the current iteration.
      *
      * <p>On the first iteration it returns empty parameters. On later iterations it returns the
      * summed squared-error gradients over all locally loaded records, together with the mean
      * half-squared error.
      *
      * NOTE(review): the method signature line was lost when this file was extracted (presumably
      * {@code public LinearRegressionParams doCompute(WorkerContext<LinearRegressionParams,
      * LinearRegressionParams> context)}). Several field names and the logger reference below were
      * also stripped. Restore them from the upstream source.
      */
     @Override
         if(context.isFirstIteration()) {
            // First iteration: send nothing; wait for the master's initial model.
            return new LinearRegressionParams();
        } else {
            // Presumably this.weights: refresh the local model from the master's last result.
            this. = context.getLastMasterResult().getParameters();
            // Presumably this.inputNum + 1: one slot per input column plus the bias term.
            double[] gradients = new double[this. + 1];
            double finalError = 0.0d;
            int size = 0;
            // Presumably iterates this.dataList.
            for(Data data) {
                // Residual of the linear prediction against the first target value.
                double error = dot(data.inputsthis.) - data.outputs[0];
                // Accumulate half of the squared residual.
                finalError += error * error / 2;
                // d/dw_i of (error^2)/2 is error * x_i; summed (not averaged) over records.
                for(int i = 0; i < gradients.lengthi++) {
                    gradients[i] += error * data.inputs[i];
                }
                size++;
            }
            // NOTE(review): if this worker loaded no records, size is 0 and finalError / size is
            // NaN (0.0 / 0); consider guarding. Presumably LOG.info.
            .info("Iteration {} with error {}"context.getCurrentIteration(), finalError / size);
            // Gradients are sent as raw sums; only the error is averaged over local records.
            return new LinearRegressionParams(gradientsfinalError / size);
        }
    }

    
Compute dot value of two vectors.
    private double dot(double[] inputsdouble[] weights) {
        double value = 0.0d;
        for(int i = 0; i < weights.lengthi++) {
            value += weights[i] * inputs[i];
        }
        return value;
    }
    @Override
    public void load(GuaguaWritableAdapter<LongWritable> currentKeyGuaguaWritableAdapter<Text> currentValue,
            WorkerContext<LinearRegressionParamsLinearRegressionParamscontext) {
        String line = currentValue.getWritable().toString();
        double[] inputData = new double[ + 1];
        double[] outputData = new double[];
        int count = 0, inputIndex = 0, outputIndex = 0;
        inputData[inputIndex++] = 1.0d;
        for(String unit.split(line)) {
            if(count < ) {
                inputData[inputIndex++] = Double.valueOf(unit);
            } else if(count >=  && count < ( + )) {
                outputData[outputIndex++] = Double.valueOf(unit);
            } else {
                break;
            }
            count++;
        }
        this..add(new Data(inputDataoutputData));
    }
    private static class Data {
        public Data(double[] inputsdouble[] outputs) {
            this. = inputs;
            this. = outputs;
        }
        private final double[] inputs;
        private final double[] outputs;
    }
New to GrepCode? Check out our FAQ X