// Web-page residue captured together with this file (kept verbatim as a comment):
// "Start line: / End line: / Snippet Preview / Snippet HTML Code / Stack Overflow Questions"
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.operator.aggregation;
 
 
 
 import java.util.List;
 
 import static com.facebook.presto.byteCode.Access.FINAL;
 import static com.facebook.presto.byteCode.Access.PRIVATE;
 import static com.facebook.presto.byteCode.Access.PUBLIC;
 import static com.facebook.presto.byteCode.Access.a;
 import static com.facebook.presto.byteCode.NamedParameterDefinition.arg;
 import static com.facebook.presto.byteCode.OpCode.NOP;
 import static com.facebook.presto.byteCode.ParameterizedType.type;
 import static com.facebook.presto.byteCode.control.IfStatement.IfStatementBuilder;
 import static com.facebook.presto.byteCode.control.IfStatement.ifStatementBuilder;
 import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata;
 import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL;
 import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_INPUT_CHANNEL;
 import static com.facebook.presto.operator.aggregation.AggregationMetadata.countInputChannels;
 import static com.facebook.presto.sql.gen.Bootstrap.BOOTSTRAP_METHOD;
 import static com.facebook.presto.sql.gen.CompilerUtils.defineClass;
 import static com.facebook.presto.sql.gen.CompilerUtils.makeClassName;
 import static com.facebook.presto.sql.gen.SqlTypeByteCodeExpression.constantType;
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 
 public class AccumulatorCompiler
 {
     {
         Class<? extends AccumulatoraccumulatorClass = generateAccumulatorClass(
                 Accumulator.class,
                 metadata,
                 classLoader);
 
         Class<? extends GroupedAccumulatorgroupedAccumulatorClass = generateAccumulatorClass(
                 GroupedAccumulator.class,
                 metadata,
                 classLoader);
 
         return new GenericAccumulatorFactoryBinder(
                 metadata.getStateSerializer(),
                 metadata.getStateFactory(),
                 accumulatorClass,
                 groupedAccumulatorClass,
                 metadata.isApproximate());
     }
 
     private static <T> Class<? extends T> generateAccumulatorClass(
             Class<T> accumulatorInterface,
             AggregationMetadata metadata,
             DynamicClassLoader classLoader)
     {
         boolean grouped = accumulatorInterface == GroupedAccumulator.class;
         boolean approximate = metadata.isApproximate();
 
         ClassDefinition definition = new ClassDefinition(
                 a(),
                 makeClassName(metadata.getName() + accumulatorInterface.getSimpleName()),
                type(Object.class),
                type(accumulatorInterface));
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        AccumulatorStateSerializer<?> stateSerializer = metadata.getStateSerializer();
        AccumulatorStateFactory<?> stateFactory = metadata.getStateFactory();
        FieldDefinition stateSerializerField = definition.declareField(a(), "stateSerializer"AccumulatorStateSerializer.class);
        FieldDefinition stateFactoryField = definition.declareField(a(), "stateFactory"AccumulatorStateFactory.class);
        FieldDefinition inputChannelsField = definition.declareField(a(), "inputChannels"type(List.classInteger.class));
        FieldDefinition maskChannelField = definition.declareField(a(), "maskChannel"type(Optional.classInteger.class));
        FieldDefinition sampleWeightChannelField = null;
        FieldDefinition confidenceField = null;
        if (approximate) {
            sampleWeightChannelField = definition.declareField(a(), "sampleWeightChannel"type(Optional.classInteger.class));
            confidenceField = definition.declareField(a(), "confidence"double.class);
        }
        FieldDefinition stateField = definition.declareField(a(), "state"grouped ? stateFactory.getGroupedStateClass() : stateFactory.getSingleStateClass());
        // Generate constructor
        generateConstructor(
                definition,
                stateSerializerField,
                stateFactoryField,
                inputChannelsField,
                maskChannelField,
                sampleWeightChannelField,
                confidenceField,
                stateField,
                grouped);
        // Generate methods
        generateAddInput(definitionstateFieldinputChannelsFieldmaskChannelFieldsampleWeightChannelFieldmetadata.getInputMetadata(), metadata.getInputFunction(), callSiteBindergrouped);
        generateGetEstimatedSize(definitionstateField);
        MethodDefinition getIntermediateType = generateGetIntermediateType(definitioncallSiteBinderstateSerializer.getSerializedType());
        MethodDefinition getFinalType = generateGetFinalType(definitioncallSiteBindermetadata.getOutputType());
        if (metadata.getIntermediateInputFunction() == null) {
            generateAddIntermediateAsCombine(definitionstateFieldstateSerializerFieldstateFactoryFieldmetadata.getCombineFunction(), stateFactory.getSingleStateClass(), grouped);
        }
        else {
            generateAddIntermediateAsIntermediateInput(definitionstateFieldmetadata.getIntermediateInputMetadata(), metadata.getIntermediateInputFunction(), callSiteBindergrouped);
        }
        if (grouped) {
            generateGroupedEvaluateIntermediate(definitionstateSerializerFieldstateField);
        }
        else {
            generateEvaluateIntermediate(definitiongetIntermediateTypestateSerializerFieldstateField);
        }
        if (grouped) {
            generateGroupedEvaluateFinal(definitionconfidenceFieldstateSerializerFieldstateFieldmetadata.getOutputFunction(), metadata.isApproximate());
        }
        else {
            generateEvaluateFinal(definitiongetFinalTypeconfidenceFieldstateSerializerFieldstateFieldmetadata.getOutputFunction(), metadata.isApproximate());
        }
        return defineClass(definitionaccumulatorInterfacecallSiteBinder.getBindings(), classLoader);
    }
    private static MethodDefinition generateGetIntermediateType(ClassDefinition definitionCallSiteBinder callSiteBinderType type)
    {
        MethodDefinition methodDefinition = definition.declareMethod(a(), "getIntermediateType"type(Type.class));
        methodDefinition.getBody()
                .append(constantType(new CompilerContext(), callSiteBindertype))
                .retObject();
        return methodDefinition;
    }
    private static MethodDefinition generateGetFinalType(ClassDefinition definitionCallSiteBinder callSiteBinderType type)
    {
        MethodDefinition methodDefinition = definition.declareMethod(a(), "getFinalType"type(Type.class));
        methodDefinition.getBody()
                .append(constantType(new CompilerContext(), callSiteBindertype))
                .retObject();
        return methodDefinition;
    }
    private static void generateGetEstimatedSize(ClassDefinition definitionFieldDefinition stateField)
    {
        definition.declareMethod(a(), "getEstimatedSize"type(long.class))
                .getBody()
                .pushThis()
                .getField(stateField)
                .invokeVirtual(stateField.getType(), "getEstimatedSize"type(long.class))
                .retLong();
    }
    private static void generateAddInput(
            ClassDefinition definition,
            FieldDefinition stateField,
            FieldDefinition inputChannelsField,
            FieldDefinition maskChannelField,
            @Nullable FieldDefinition sampleWeightChannelField,
            List<ParameterMetadataparameterMetadatas,
            Method inputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        CompilerContext context = new CompilerContext();
        ImmutableList.Builder<NamedParameterDefinitionparameters = ImmutableList.builder();
        if (grouped) {
            parameters.add(arg("groupIdsBlock"GroupByIdBlock.class));
        }
        parameters.add(arg("page"Page.class));
        Block body = definition.declareMethod(contexta(), "addInput"type(void.class), parameters.build())
                .getBody();
        if (grouped) {
            generateEnsureCapacity(stateFieldbody);
        }
        List<VariableparameterVariables = new ArrayList<>();
        for (int i = 0; i < countInputChannels(parameterMetadatas); i++) {
            parameterVariables.add(context.declareVariable(com.facebook.presto.spi.block.Block.class"block" + i));
        }
        Variable masksBlock = context.declareVariable(com.facebook.presto.spi.block.Block.class"masksBlock");
        Variable sampleWeightsBlock = null;
        if (sampleWeightChannelField != null) {
            sampleWeightsBlock = context.declareVariable(com.facebook.presto.spi.block.Block.class"sampleWeightsBlock");
        }
        body.comment("masksBlock = maskChannel.map(page.blockGetter()).orElse(null);")
                .pushThis()
                .getField(maskChannelField)
                .getVariable("page")
                .invokeStatic(type(AggregationUtils.class), "pageBlockGetter"type(Function.classInteger.classcom.facebook.presto.spi.block.Block.class), type(Page.class))
                .invokeVirtual(Optional.class"map"Optional.classFunction.class)
                .pushNull()
                .invokeVirtual(Optional.class"orElse"Object.classObject.class)
                .checkCast(com.facebook.presto.spi.block.Block.class)
                .putVariable(masksBlock);
        if (sampleWeightChannelField != null) {
            body.comment("sampleWeightsBlock = sampleWeightChannel.map(page.blockGetter()).get();")
                    .pushThis()
                    .getField(sampleWeightChannelField)
                    .getVariable("page")
                    .invokeStatic(type(AggregationUtils.class), "pageBlockGetter"type(Function.classInteger.classcom.facebook.presto.spi.block.Block.class), type(Page.class))
                    .invokeVirtual(Optional.class"map"Optional.classFunction.class)
                    .invokeVirtual(Optional.class"get"Object.class)
                    .checkCast(com.facebook.presto.spi.block.Block.class)
                    .putVariable(sampleWeightsBlock);
        }
        // Get all parameter blocks
        for (int i = 0; i < countInputChannels(parameterMetadatas); i++) {
            body.comment("%s = page.getBlock(inputChannels.get(%d));"parameterVariables.get(i).getName(), i)
                    .getVariable("page")
                    .pushThis()
                    .getField(inputChannelsField)
                    .push(i)
                    .invokeInterface(List.class"get"Object.classint.class)
                    .checkCast(Integer.class)
                    .invokeVirtual(Integer.class"intValue"int.class)
                    .invokeVirtual(Page.class"getBlock"com.facebook.presto.spi.block.Block.classint.class)
                    .putVariable(parameterVariables.get(i));
        }
        Block block = generateInputForLoop(stateFieldparameterMetadatasinputFunctioncontextparameterVariablesmasksBlocksampleWeightsBlockcallSiteBindergrouped);
        body.append(block)
                .ret();
    }
    private static Block generateInputForLoop(
            FieldDefinition stateField,
            List<ParameterMetadataparameterMetadatas,
            Method inputFunction,
            CompilerContext context,
            List<VariableparameterVariables,
            Variable masksBlock,
            @Nullable Variable sampleWeightsBlock,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        // For-loop over rows
        Variable positionVariable = context.declareVariable(int.class"position");
        Variable sampleWeightVariable = null;
        if (sampleWeightsBlock != null) {
            sampleWeightVariable = context.declareVariable(long.class"sampleWeight");
        }
        Variable rowsVariable = context.declareVariable(int.class"rows");
        Block block = new Block(context)
                .getVariable("page")
                .invokeVirtual(Page.class"getPositionCount"int.class)
                .putVariable(rowsVariable)
                .initializeVariable(positionVariable);
        if (sampleWeightVariable != null) {
            block.initializeVariable(sampleWeightVariable);
        }
        Block loopBody = generateInvokeInputFunction(contextstateFieldpositionVariablesampleWeightVariableparameterVariablesparameterMetadatasinputFunctioncallSiteBindergrouped);
        //  Wrap with null checks
        List<Booleannullable = new ArrayList<>();
        for (ParameterMetadata metadata : parameterMetadatas) {
            if (metadata.getParameterType() == ) {
                nullable.add(false);
            }
            else if (metadata.getParameterType() == ) {
                nullable.add(true);
            }
        }
        checkState(nullable.size() == parameterVariables.size(), "Number of parameters does not match");
        for (int i = 0; i < parameterVariables.size(); i++) {
            if (!nullable.get(i)) {
                IfStatementBuilder builder = ifStatementBuilder(context);
                Variable variableDefinition = parameterVariables.get(i);
                builder.comment("if(!%s.isNull(position))"variableDefinition.getName())
                        .condition(new Block(context)
                                .getVariable(variableDefinition)
                                .getVariable(positionVariable)
                                .invokeInterface(com.facebook.presto.spi.block.Block.class"isNull"boolean.classint.class))
                        .ifTrue()
                        .ifFalse(loopBody);
                loopBody = new Block(context).append(builder.build());
            }
        }
        // Check that sample weight is > 0 (also checks the mask)
        if (sampleWeightVariable != null) {
            loopBody = generateComputeSampleWeightAndCheckGreaterThanZero(contextloopBodysampleWeightVariablemasksBlocksampleWeightsBlockpositionVariable);
        }
        // Otherwise just check the mask
        else {
            IfStatementBuilder builder = ifStatementBuilder(context);
            builder.comment("if(testMask(%s, position))"masksBlock.getName())
                    .condition(new Block(context)
                            .getVariable(masksBlock)
                            .getVariable(positionVariable)
                            .invokeStatic(CompilerOperations.class"testMask"boolean.classcom.facebook.presto.spi.block.Block.classint.class))
                    .ifTrue(loopBody)
                    .ifFalse();
            loopBody = new Block(context).append(builder.build());
        }
        block.append(new ForLoop.ForLoopBuilder(context)
                .initialize(new Block(context).putVariable(positionVariable, 0))
                .condition(new Block(context)
                        .getVariable(positionVariable)
                        .getVariable(rowsVariable)
                        .invokeStatic(CompilerOperations.class"lessThan"boolean.classint.classint.class))
                .update(new Block(context).incrementVariable(positionVariable, (byte) 1))
                .body(loopBody)
                .build());
        return block;
    }
    private static Block generateComputeSampleWeightAndCheckGreaterThanZero(CompilerContext contextBlock bodyVariable sampleWeightVariable masksVariable sampleWeightsVariable position)
    {
        Block block = new Block(context)
                .comment("sampleWeight = computeSampleWeight(masks, sampleWeights, position);")
                .getVariable(masks)
                .getVariable(sampleWeights)
                .getVariable(position)
                .invokeStatic(ApproximateUtils.class"computeSampleWeight"long.classcom.facebook.presto.spi.block.Block.classcom.facebook.presto.spi.block.Block.classint.class)
                .putVariable(sampleWeight);
        IfStatementBuilder builder = ifStatementBuilder(context);
        builder.comment("if(sampleWeight > 0)")
                .condition(new Block(context)
                        .getVariable(sampleWeight)
                        .invokeStatic(CompilerOperations.class"longGreaterThanZero"boolean.classlong.class))
                .ifTrue(body)
                .ifFalse();
        return block.append(builder.build());
    }
    private static Block generateInvokeInputFunction(
            CompilerContext context,
            FieldDefinition stateField,
            Variable position,
            @Nullable Variable sampleWeight,
            List<VariableparameterVariables,
            List<ParameterMetadataparameterMetadatas,
            Method inputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        Block block = new Block(context);
        if (grouped) {
            generateSetGroupIdFromGroupIdsBlock(stateFieldpositionblock);
        }
        block.comment("Call input function with unpacked Block arguments");
        Class<?>[] parameters = inputFunction.getParameterTypes();
        int inputChannel = 0;
        for (int i = 0; i < parameters.lengthi++) {
            ParameterMetadata parameterMetadata = parameterMetadatas.get(i);
            switch (parameterMetadata.getParameterType()) {
                case :
                    block.pushThis().getField(stateField);
                    break;
                case :
                    block.getVariable(position);
                    break;
                case :
                    checkNotNull(sampleWeight"sampleWeight is null");
                    block.getVariable(sampleWeight);
                    break;
                case :
                    block.getVariable(parameterVariables.get(inputChannel));
                    inputChannel++;
                    break;
                case :
                    Block getBlockByteCode = new Block(context)
                            .getVariable(parameterVariables.get(inputChannel));
                    pushStackType(blockparameterMetadata.getSqlType(), getBlockByteCodeparameters[i], callSiteBinder);
                    inputChannel++;
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported parameter type: " + parameterMetadata.getParameterType());
            }
        }
        block.invokeStatic(inputFunction);
        return block;
    }
    // Assumes that there is a variable named 'position' in the block, which is the current index
    private static void pushStackType(Block blockType sqlTypeBlock getBlockByteCodeClass<?> parameterCallSiteBinder callSiteBinder)
    {
        if (parameter == com.facebook.presto.spi.block.Block.class) {
            block.append(getBlockByteCode);
        }
        else if (parameter == long.class) {
            block.comment("%s.getLong(block, position)"sqlType.getTypeSignature())
                    .append(SqlTypeByteCodeExpression.constantType(new CompilerContext(), callSiteBindersqlType))
                    .append(getBlockByteCode)
                    .getVariable("position")
                    .invokeInterface(Type.class"getLong"long.classcom.facebook.presto.spi.block.Block.classint.class);
        }
        else if (parameter == double.class) {
            block.comment("%s.getDouble(block, position)"sqlType.getTypeSignature())
                    .append(SqlTypeByteCodeExpression.constantType(new CompilerContext(), callSiteBindersqlType))
                    .append(getBlockByteCode)
                    .getVariable("position")
                    .invokeInterface(Type.class"getDouble"double.classcom.facebook.presto.spi.block.Block.classint.class);
        }
        else if (parameter == boolean.class) {
            block.comment("%s.getBoolean(block, position)"sqlType.getTypeSignature())
                    .append(SqlTypeByteCodeExpression.constantType(new CompilerContext(), callSiteBindersqlType))
                    .append(getBlockByteCode)
                    .getVariable("position")
                    .invokeInterface(Type.class"getBoolean"boolean.classcom.facebook.presto.spi.block.Block.classint.class);
        }
        else if (parameter == Slice.class) {
            block.comment("%s.getBoolean(block, position)"sqlType.getTypeSignature())
                    .append(SqlTypeByteCodeExpression.constantType(new CompilerContext(), callSiteBindersqlType))
                    .append(getBlockByteCode)
                    .getVariable("position")
                    .invokeInterface(Type.class"getSlice"Slice.classcom.facebook.presto.spi.block.Block.classint.class);
        }
        else {
            throw new IllegalArgumentException("Unsupported parameter type: " + parameter.getSimpleName());
        }
    }
    private static void generateAddIntermediateAsCombine(
            ClassDefinition definition,
            FieldDefinition stateField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateFactoryField,
            Method combineFunction,
            Class<?> singleStateClass,
            boolean grouped)
    {
        CompilerContext context = new CompilerContext();
        Block body = declareAddIntermediate(definitiongroupedcontext);
        Variable scratchStateVariable = context.declareVariable(singleStateClass"scratchState");
        Variable positionVariable = context.declareVariable(int.class"position");
        body.comment("scratchState = stateFactory.createSingleState();")
                .pushThis()
                .getField(stateFactoryField)
                .invokeInterface(AccumulatorStateFactory.class"createSingleState"Object.class)
                .checkCast(scratchStateVariable.getType())
                .putVariable(scratchStateVariable);
        if (grouped) {
            generateEnsureCapacity(stateFieldbody);
        }
        Block loopBody = new Block(context);
        if (grouped) {
            generateSetGroupIdFromGroupIdsBlock(stateFieldpositionVariableloopBody);
        }
        loopBody.comment("stateSerializer.deserialize(block, position, scratchState)")
                .pushThis()
                .getField(stateSerializerField)
                .getVariable("block")
                .getVariable(positionVariable)
                .getVariable(scratchStateVariable)
                .invokeInterface(AccumulatorStateSerializer.class"deserialize"void.classcom.facebook.presto.spi.block.Block.classint.classObject.class);
        loopBody.comment("combine(state, scratchState)")
                .pushThis()
                .getField(stateField)
                .getVariable("scratchState")
                .invokeStatic(combineFunction);
        body.append(generateBlockNonNullPositionForLoop(contextpositionVariableloopBody))
                .ret();
    }
    private static void generateSetGroupIdFromGroupIdsBlock(FieldDefinition stateFieldVariable positionVariableBlock block)
    {
        block.comment("state.setGroupId(groupIdsBlock.getGroupId(position))")
                .pushThis()
                .getField(stateField)
                .getVariable("groupIdsBlock")
                .getVariable(positionVariable)
                .invokeVirtual(GroupByIdBlock.class"getGroupId"long.classint.class)
                .invokeVirtual(stateField.getType(), "setGroupId"type(void.class), type(long.class));
    }
    private static void generateEnsureCapacity(FieldDefinition stateFieldBlock block)
    {
        block.comment("state.ensureCapacity(groupIdsBlock.getGroupCount())")
                .pushThis()
                .getField(stateField)
                .getVariable("groupIdsBlock")
                .invokeVirtual(GroupByIdBlock.class"getGroupCount"long.class)
                .invokeVirtual(stateField.getType(), "ensureCapacity"type(void.class), type(long.class));
    }
    private static Block declareAddIntermediate(ClassDefinition definitionboolean groupedCompilerContext context)
    {
        ImmutableList.Builder<NamedParameterDefinitionparameters = ImmutableList.builder();
        if (grouped) {
            parameters.add(arg("groupIdsBlock"GroupByIdBlock.class));
        }
        parameters.add(arg("block"com.facebook.presto.spi.block.Block.class));
        return definition.declareMethod(
                context,
                a(),
                "addIntermediate",
                type(void.class),
                parameters.build())
                .getBody();
    }
    private static void generateAddIntermediateAsIntermediateInput(
            ClassDefinition definition,
            FieldDefinition stateField,
            List<ParameterMetadataparameterMetadatas,
            Method intermediateInputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        CompilerContext context = new CompilerContext();
        Block body = declareAddIntermediate(definitiongroupedcontext);
        if (grouped) {
            generateEnsureCapacity(stateFieldbody);
        }
        Variable positionVariable = context.declareVariable(int.class"position");
        Block loopBody = generateInvokeInputFunction(contextstateFieldpositionVariablenull, ImmutableList.of(context.getVariable("block")), parameterMetadatasintermediateInputFunctioncallSiteBindergrouped);
        body.append(generateBlockNonNullPositionForLoop(contextpositionVariableloopBody))
                .ret();
    }
    // Generates a for-loop with a local variable named "position" defined, with the current position in the block,
    // loopBody will only be executed for non-null positions in the Block
    private static Block generateBlockNonNullPositionForLoop(CompilerContext contextVariable positionVariableBlock loopBody)
    {
        Variable rowsVariable = context.declareVariable(int.class"rows");
        Block block = new Block(context)
                .getVariable("block")
                .invokeInterface(com.facebook.presto.spi.block.Block.class"getPositionCount"int.class)
                .putVariable(rowsVariable);
        IfStatementBuilder builder = ifStatementBuilder(context);
        builder.comment("if(!block.isNull(position))")
                .condition(new Block(context)
                        .getVariable("block")
                        .getVariable(positionVariable)
                        .invokeInterface(com.facebook.presto.spi.block.Block.class"isNull"boolean.classint.class))
                .ifTrue()
                .ifFalse(loopBody);
        block.append(new ForLoop.ForLoopBuilder(context)
                .initialize(new Block(context).putVariable(positionVariable, 0))
                .condition(new Block(context)
                        .getVariable(positionVariable)
                        .getVariable(rowsVariable)
                        .invokeStatic(CompilerOperations.class"lessThan"boolean.classint.classint.class))
                .update(new Block(context).incrementVariable(positionVariable, (byte) 1))
                .body(builder.build())
                .build());
        return block;
    }
    private static void generateGroupedEvaluateIntermediate(ClassDefinition definitionFieldDefinition stateSerializerFieldFieldDefinition stateField)
    {
        definition.declareMethod(
                a(),
                "evaluateIntermediate",
                type(void.class),
                arg("groupId"int.class),
                arg("out"BlockBuilder.class))
                .getBody()
                .comment("state.setGroupId(groupId)")
                .pushThis()
                .getField(stateField)
                .getVariable("groupId")
                .intToLong()
                .invokeVirtual(stateField.getType(), "setGroupId"type(void.class), type(long.class))
                .comment("stateSerializer.serialize(state, out)")
                .pushThis()
                .getField(stateSerializerField)
                .pushThis()
                .getField(stateField)
                .getVariable("out")
                .invokeInterface(AccumulatorStateSerializer.class"serialize"void.classObject.classBlockBuilder.class)
                .ret();
    }
    private static void generateEvaluateIntermediate(ClassDefinition definitionMethodDefinition getIntermediateTypeFieldDefinition stateSerializerFieldFieldDefinition stateField)
    {
        CompilerContext context = new CompilerContext();
        definition.declareMethod(
                context,
                a(),
                "evaluateIntermediate",
                type(void.class),
                arg("out"BlockBuilder.class))
                .getBody()
                .comment("stateSerializer.serialize(state, out)")
                .pushThis()
                .getField(stateSerializerField)
                .pushThis()
                .getField(stateField)
                .getVariable("out")
                .invokeInterface(AccumulatorStateSerializer.class"serialize"void.classObject.classBlockBuilder.class)
                .ret();
    }
    private static void generateGroupedEvaluateFinal(
            ClassDefinition definition,
            FieldDefinition confidenceField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateField,
            @Nullable Method outputFunction,
            boolean approximate)
    {
        Block body = definition.declareMethod(
                a(),
                "evaluateFinal",
                type(void.class),
                arg("groupId"int.class),
                arg("out"BlockBuilder.class))
                .getBody()
                .comment("state.setGroupId(groupId)")
                .pushThis()
                .getField(stateField)
                .getVariable("groupId")
                .intToLong()
                .invokeVirtual(stateField.getType(), "setGroupId"type(void.class), type(long.class));
        if (outputFunction != null) {
            body.comment("output(state, out)")
                    .pushThis()
                    .getField(stateField);
            if (approximate) {
                checkNotNull(confidenceField"confidenceField is null");
                body.pushThis().getField(confidenceField);
            }
            body.getVariable("out")
                    .invokeStatic(outputFunction);
        }
        else {
            checkArgument(!approximate"Approximate aggregations must specify an output function");
            body.comment("stateSerializer.serialize(state, out)")
                    .pushThis()
                    .getField(stateSerializerField)
                    .pushThis()
                    .getField(stateField)
                    .getVariable("out")
                    .invokeInterface(AccumulatorStateSerializer.class"serialize"void.classObject.classBlockBuilder.class);
        }
        body.ret();
    }
    private static void generateEvaluateFinal(
            ClassDefinition definition,
            MethodDefinition getFinalType,
            FieldDefinition confidenceField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateField,
            @Nullable
            Method outputFunction,
            boolean approximate)
    {
        Block body = definition.declareMethod(
                a(),
                "evaluateFinal",
                type(void.class),
                arg("out"BlockBuilder.class))
                .getBody();
        if (outputFunction != null) {
            body.comment("output(state, out)")
                    .pushThis()
                    .getField(stateField);
            if (approximate) {
                checkNotNull(confidenceField"confidenceField is null");
                body.pushThis().getField(confidenceField);
            }
            body.getVariable("out")
                    .invokeStatic(outputFunction);
        }
        else {
            checkArgument(!approximate"Approximate aggregations must specify an output function");
            body.comment("stateSerializer.serialize(state, out)")
                    .pushThis()
                    .getField(stateSerializerField)
                    .pushThis()
                    .getField(stateField)
                    .getVariable("out")
                    .invokeInterface(AccumulatorStateSerializer.class"serialize"void.classObject.classBlockBuilder.class);
        }
        body.ret();
    }
    private static void generateConstructor(
            ClassDefinition definition,
            FieldDefinition stateSerializerField,
            FieldDefinition stateFactoryField,
            FieldDefinition inputChannelsField,
            FieldDefinition maskChannelField,
            @Nullable FieldDefinition sampleWeightChannelField,
            @Nullable FieldDefinition confidenceField,
            FieldDefinition stateField,
            boolean grouped)
    {
        Block body = definition.declareConstructor(
                a(),
                arg("stateSerializer"AccumulatorStateSerializer.class),
                arg("stateFactory"AccumulatorStateFactory.class),
                arg("inputChannels"type(List.classInteger.class)),
                arg("maskChannel"type(Optional.classInteger.class)),
                arg("sampleWeightChannel"type(Optional.classInteger.class)),
                arg("confidence"double.class))
                .getBody()
                .comment("super();")
                .pushThis()
                .invokeConstructor(Object.class);
        generateCastCheckNotNullAndAssign(bodystateSerializerField"stateSerializer");
        generateCastCheckNotNullAndAssign(bodystateFactoryField"stateFactory");
        generateCastCheckNotNullAndAssign(bodyinputChannelsField"inputChannels");
        generateCastCheckNotNullAndAssign(bodymaskChannelField"maskChannel");
        if (sampleWeightChannelField != null) {
            generateCastCheckNotNullAndAssign(bodysampleWeightChannelField"sampleWeightChannel");
        }
        String createState;
        if (grouped) {
            createState = "createGroupedState";
        }
        else {
            createState = "createSingleState";
        }
        if (confidenceField != null) {
            body.comment("this.confidence = confidence")
                    .pushThis()
                    .getVariable("confidence")
                    .putField(confidenceField);
        }
        body.comment("this.state = stateFactory.%s()"createState)
                .pushThis()
                .getVariable("stateFactory")
                .invokeInterface(AccumulatorStateFactory.classcreateStateObject.class)
                .checkCast(stateField.getType())
                .putField(stateField)
                .ret();
    }
    private static void generateCastCheckNotNullAndAssign(Block blockFieldDefinition fieldString variableName)
    {
        block.comment("this.%s = checkNotNull(%s, \"%s is null\""field.getName(), variableNamevariableName)
                .pushThis()
                .getVariable(variableName)
                .checkCast(field.getType())
                .push(variableName + " is null")
                .invokeStatic(Preconditions.class"checkNotNull"Object.classObject.classObject.class)
                .checkCast(field.getType())
                .putField(field);
    }
New to GrepCode? Check out our FAQ X