Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.operator.aggregation;
 
 
 
 import com.google.common.collect.ImmutableList;

 import java.lang.invoke.MethodHandle;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.function.Function;

 import static com.facebook.presto.byteCode.Access.FINAL;
 import static com.facebook.presto.byteCode.Access.PRIVATE;
 import static com.facebook.presto.byteCode.Access.PUBLIC;
 import static com.facebook.presto.byteCode.Access.a;
 import static com.facebook.presto.byteCode.OpCode.NOP;
 import static com.facebook.presto.byteCode.Parameter.arg;
 import static com.facebook.presto.byteCode.ParameterizedType.type;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantInt;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantString;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.invokeStatic;
 import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata;
 import static com.facebook.presto.operator.aggregation.AggregationMetadata.countInputChannels;
 import static com.facebook.presto.sql.gen.ByteCodeUtils.invoke;
 import static com.facebook.presto.sql.gen.CompilerUtils.defineClass;
 import static com.facebook.presto.sql.gen.CompilerUtils.makeClassName;
 import static com.facebook.presto.sql.gen.SqlTypeByteCodeExpression.constantType;
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 
 public class AccumulatorCompiler
 {
     {
         Class<? extends AccumulatoraccumulatorClass = generateAccumulatorClass(
                 Accumulator.class,
                 metadata,
                 classLoader);
 
         Class<? extends GroupedAccumulatorgroupedAccumulatorClass = generateAccumulatorClass(
                 GroupedAccumulator.class,
                 metadata,
                 classLoader);
 
         return new GenericAccumulatorFactoryBinder(
                 metadata.getStateSerializer(),
                 metadata.getStateFactory(),
                 accumulatorClass,
                 groupedAccumulatorClass,
                 metadata.isApproximate());
     }
 
     private static <T> Class<? extends T> generateAccumulatorClass(
             Class<T> accumulatorInterface,
             AggregationMetadata metadata,
             DynamicClassLoader classLoader)
     {
         boolean grouped = accumulatorInterface == GroupedAccumulator.class;
         boolean approximate = metadata.isApproximate();
 
         ClassDefinition definition = new ClassDefinition(
                 a(),
                makeClassName(metadata.getName() + accumulatorInterface.getSimpleName()),
                type(Object.class),
                type(accumulatorInterface));
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        AccumulatorStateSerializer<?> stateSerializer = metadata.getStateSerializer();
        AccumulatorStateFactory<?> stateFactory = metadata.getStateFactory();
        FieldDefinition stateSerializerField = definition.declareField(a(), "stateSerializer"AccumulatorStateSerializer.class);
        FieldDefinition stateFactoryField = definition.declareField(a(), "stateFactory"AccumulatorStateFactory.class);
        FieldDefinition inputChannelsField = definition.declareField(a(), "inputChannels"type(List.classInteger.class));
        FieldDefinition maskChannelField = definition.declareField(a(), "maskChannel"type(Optional.classInteger.class));
        FieldDefinition sampleWeightChannelField = null;
        FieldDefinition confidenceField = null;
        if (approximate) {
            sampleWeightChannelField = definition.declareField(a(), "sampleWeightChannel"type(Optional.classInteger.class));
            confidenceField = definition.declareField(a(), "confidence"double.class);
        }
        FieldDefinition stateField = definition.declareField(a(), "state"grouped ? stateFactory.getGroupedStateClass() : stateFactory.getSingleStateClass());
        // Generate constructor
        generateConstructor(
                definition,
                stateSerializerField,
                stateFactoryField,
                inputChannelsField,
                maskChannelField,
                sampleWeightChannelField,
                confidenceField,
                stateField,
                grouped);
        // Generate methods
        generateAddInput(definitionstateFieldinputChannelsFieldmaskChannelFieldsampleWeightChannelFieldmetadata.getInputMetadata(), metadata.getInputFunction(), callSiteBindergrouped);
        generateGetEstimatedSize(definitionstateField);
        generateGetIntermediateType(definitioncallSiteBinderstateSerializer.getSerializedType());
        generateGetFinalType(definitioncallSiteBindermetadata.getOutputType());
        if (metadata.getIntermediateInputFunction() == null) {
            generateAddIntermediateAsCombine(definitionstateFieldstateSerializerFieldstateFactoryFieldmetadata.getCombineFunction(), stateFactory.getSingleStateClass(), callSiteBindergrouped);
        }
        else {
            generateAddIntermediateAsIntermediateInput(definitionstateFieldmetadata.getIntermediateInputMetadata(), metadata.getIntermediateInputFunction(), callSiteBindergrouped);
        }
        if (grouped) {
            generateGroupedEvaluateIntermediate(definitionstateSerializerFieldstateField);
        }
        else {
            generateEvaluateIntermediate(definitionstateSerializerFieldstateField);
        }
        if (grouped) {
            generateGroupedEvaluateFinal(definitionconfidenceFieldstateSerializerFieldstateFieldmetadata.getOutputFunction(), metadata.isApproximate(), callSiteBinder);
        }
        else {
            generateEvaluateFinal(definitionconfidenceFieldstateSerializerFieldstateFieldmetadata.getOutputFunction(), metadata.isApproximate(), callSiteBinder);
        }
        return defineClass(definitionaccumulatorInterfacecallSiteBinder.getBindings(), classLoader);
    }
    private static MethodDefinition generateGetIntermediateType(ClassDefinition definitionCallSiteBinder callSiteBinderType type)
    {
        MethodDefinition methodDefinition = definition.declareMethod(a(), "getIntermediateType"type(Type.class));
        methodDefinition.getBody()
                .append(constantType(callSiteBindertype))
                .retObject();
        return methodDefinition;
    }
    private static MethodDefinition generateGetFinalType(ClassDefinition definitionCallSiteBinder callSiteBinderType type)
    {
        MethodDefinition methodDefinition = definition.declareMethod(a(), "getFinalType"type(Type.class));
        methodDefinition.getBody()
                .append(constantType(callSiteBindertype))
                .retObject();
        return methodDefinition;
    }
    private static void generateGetEstimatedSize(ClassDefinition definitionFieldDefinition stateField)
    {
        MethodDefinition method = definition.declareMethod(a(), "getEstimatedSize"type(long.class));
        ByteCodeExpression state = method.getThis().getField(stateField);
        method.getBody()
                .append(state.invoke("getEstimatedSize"long.class).ret());
    }
    private static void generateAddInput(
            ClassDefinition definition,
            FieldDefinition stateField,
            FieldDefinition inputChannelsField,
            FieldDefinition maskChannelField,
            @Nullable FieldDefinition sampleWeightChannelField,
            List<ParameterMetadataparameterMetadatas,
            MethodHandle inputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        ImmutableList.Builder<Parameterparameters = ImmutableList.builder();
        if (grouped) {
            parameters.add(arg("groupIdsBlock"GroupByIdBlock.class));
        }
        Parameter page = arg("page"Page.class);
        parameters.add(page);
        MethodDefinition method = definition.declareMethod(a(), "addInput"type(void.class), parameters.build());
        Scope scope = method.getScope();
        Block body = method.getBody();
        Variable thisVariable = method.getThis();
        if (grouped) {
            generateEnsureCapacity(scopestateFieldbody);
        }
        List<VariableparameterVariables = new ArrayList<>();
        for (int i = 0; i < countInputChannels(parameterMetadatas); i++) {
            parameterVariables.add(scope.declareVariable(com.facebook.presto.spi.block.Block.class"block" + i));
        }
        Variable masksBlock = scope.declareVariable(com.facebook.presto.spi.block.Block.class"masksBlock");
        Variable sampleWeightsBlock = null;
        if (sampleWeightChannelField != null) {
            sampleWeightsBlock = scope.declareVariable(com.facebook.presto.spi.block.Block.class"sampleWeightsBlock");
        }
        body.comment("masksBlock = maskChannel.map(page.blockGetter()).orElse(null);")
                .append(thisVariable.getField(maskChannelField))
                .append(page)
                .invokeStatic(type(AggregationUtils.class), "pageBlockGetter"type(Function.classInteger.classcom.facebook.presto.spi.block.Block.class), type(Page.class))
                .invokeVirtual(Optional.class"map"Optional.classFunction.class)
                .pushNull()
                .invokeVirtual(Optional.class"orElse"Object.classObject.class)
                .checkCast(com.facebook.presto.spi.block.Block.class)
                .putVariable(masksBlock);
        if (sampleWeightChannelField != null) {
            body.comment("sampleWeightsBlock = sampleWeightChannel.map(page.blockGetter()).get();")
                    .append(thisVariable.getField(sampleWeightChannelField))
                    .append(page)
                    .invokeStatic(type(AggregationUtils.class), "pageBlockGetter"type(Function.classInteger.classcom.facebook.presto.spi.block.Block.class), type(Page.class))
                    .invokeVirtual(Optional.class"map"Optional.classFunction.class)
                    .invokeVirtual(Optional.class"get"Object.class)
                    .checkCast(com.facebook.presto.spi.block.Block.class)
                    .putVariable(sampleWeightsBlock);
        }
        // Get all parameter blocks
        for (int i = 0; i < countInputChannels(parameterMetadatas); i++) {
            body.comment("%s = page.getBlock(inputChannels.get(%d));"parameterVariables.get(i).getName(), i)
                    .append(page)
                    .append(thisVariable.getField(inputChannelsField))
                    .push(i)
                    .invokeInterface(List.class"get"Object.classint.class)
                    .checkCast(Integer.class)
                    .invokeVirtual(Integer.class"intValue"int.class)
                    .invokeVirtual(Page.class"getBlock"com.facebook.presto.spi.block.Block.classint.class)
                    .putVariable(parameterVariables.get(i));
        }
        Block block = generateInputForLoop(stateFieldparameterMetadatasinputFunctionscopeparameterVariablesmasksBlocksampleWeightsBlockcallSiteBindergrouped);
        body.append(block);
        body.ret();
    }
    private static Block generateInputForLoop(
            FieldDefinition stateField,
            List<ParameterMetadataparameterMetadatas,
            MethodHandle inputFunction,
            Scope scope,
            List<VariableparameterVariables,
            Variable masksBlock,
            @Nullable Variable sampleWeightsBlock,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        // For-loop over rows
        Variable page = scope.getVariable("page");
        Variable positionVariable = scope.declareVariable(int.class"position");
        Variable sampleWeightVariable = null;
        if (sampleWeightsBlock != null) {
            sampleWeightVariable = scope.declareVariable(long.class"sampleWeight");
        }
        Variable rowsVariable = scope.declareVariable(int.class"rows");
        Block block = new Block()
                .append(page)
                .invokeVirtual(Page.class"getPositionCount"int.class)
                .putVariable(rowsVariable)
                .initializeVariable(positionVariable);
        if (sampleWeightVariable != null) {
            block.initializeVariable(sampleWeightVariable);
        }
        ByteCodeNode loopBody = generateInvokeInputFunction(scopestateFieldpositionVariablesampleWeightVariableparameterVariablesparameterMetadatasinputFunctioncallSiteBindergrouped);
        //  Wrap with null checks
        List<Booleannullable = new ArrayList<>();
        for (ParameterMetadata metadata : parameterMetadatas) {
            switch (metadata.getParameterType()) {
                case :
                case :
                    nullable.add(false);
                    break;
                case :
                    nullable.add(true);
                    break;
                default// do nothing
            }
        }
        checkState(nullable.size() == parameterVariables.size(), "Number of parameters does not match");
        for (int i = 0; i < parameterVariables.size(); i++) {
            if (!nullable.get(i)) {
                Variable variableDefinition = parameterVariables.get(i);
                loopBody = new IfStatement("if(!%s.isNull(position))"variableDefinition.getName())
                        .condition(new Block()
                                .getVariable(variableDefinition)
                                .getVariable(positionVariable)
                                .invokeInterface(com.facebook.presto.spi.block.Block.class"isNull"boolean.classint.class))
                        .ifFalse(loopBody);
            }
        }
        // Check that sample weight is > 0 (also checks the mask)
        if (sampleWeightVariable != null) {
            loopBody = generateComputeSampleWeightAndCheckGreaterThanZero(loopBodysampleWeightVariablemasksBlocksampleWeightsBlockpositionVariable);
        }
        // Otherwise just check the mask
        else {
            loopBody = new IfStatement("if(testMask(%s, position))"masksBlock.getName())
                    .condition(new Block()
                            .getVariable(masksBlock)
                            .getVariable(positionVariable)
                            .invokeStatic(CompilerOperations.class"testMask"boolean.classcom.facebook.presto.spi.block.Block.classint.class))
                    .ifTrue(loopBody);
        }
        block.append(new ForLoop()
                .initialize(new Block().putVariable(positionVariable, 0))
                .condition(new Block()
                        .getVariable(positionVariable)
                        .getVariable(rowsVariable)
                        .invokeStatic(CompilerOperations.class"lessThan"boolean.classint.classint.class))
                .update(new Block().incrementVariable(positionVariable, (byte) 1))
                .body(loopBody));
        return block;
    }
    private static ByteCodeNode generateComputeSampleWeightAndCheckGreaterThanZero(ByteCodeNode bodyVariable sampleWeightVariable masksVariable sampleWeightsVariable position)
    {
        Block block = new Block()
                .comment("sampleWeight = computeSampleWeight(masks, sampleWeights, position);")
                .getVariable(masks)
                .getVariable(sampleWeights)
                .getVariable(position)
                .invokeStatic(ApproximateUtils.class"computeSampleWeight"long.classcom.facebook.presto.spi.block.Block.classcom.facebook.presto.spi.block.Block.classint.class)
                .putVariable(sampleWeight);
        block.append(new IfStatement("if(sampleWeight > 0)")
                .condition(new Block()
                        .getVariable(sampleWeight)
                        .invokeStatic(CompilerOperations.class"longGreaterThanZero"boolean.classlong.class))
                .ifTrue(body)
                .ifFalse());
        return block;
    }
    private static Block generateInvokeInputFunction(
            Scope scope,
            FieldDefinition stateField,
            Variable position,
            @Nullable Variable sampleWeight,
            List<VariableparameterVariables,
            List<ParameterMetadataparameterMetadatas,
            MethodHandle inputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        Block block = new Block();
        if (grouped) {
            generateSetGroupIdFromGroupIdsBlock(scopestateFieldblock);
        }
        block.comment("Call input function with unpacked Block arguments");
        Class<?>[] parameters = inputFunction.type().parameterArray();
        int inputChannel = 0;
        for (int i = 0; i < parameters.lengthi++) {
            ParameterMetadata parameterMetadata = parameterMetadatas.get(i);
            switch (parameterMetadata.getParameterType()) {
                case :
                    block.append(scope.getThis().getField(stateField));
                    break;
                case :
                    block.getVariable(position);
                    break;
                case :
                    checkNotNull(sampleWeight"sampleWeight is null");
                    block.getVariable(sampleWeight);
                    break;
                case :
                case :
                    block.getVariable(parameterVariables.get(inputChannel));
                    inputChannel++;
                    break;
                case :
                    Block getBlockByteCode = new Block()
                            .getVariable(parameterVariables.get(inputChannel));
                    pushStackType(scopeblockparameterMetadata.getSqlType(), getBlockByteCodeparameters[i], callSiteBinder);
                    inputChannel++;
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported parameter type: " + parameterMetadata.getParameterType());
            }
        }
        block.append(invoke(callSiteBinder.bind(inputFunction), "input"));
        return block;
    }
    // Assumes that there is a variable named 'position' in the block, which is the current index
    private static void pushStackType(Scope scopeBlock blockType sqlTypeBlock getBlockByteCodeClass<?> parameterCallSiteBinder callSiteBinder)
    {
        Variable position = scope.getVariable("position");
        if (parameter == long.class) {
            block.comment("%s.getLong(block, position)"sqlType.getTypeSignature())
                    .append(constantType(callSiteBindersqlType))
                    .append(getBlockByteCode)
                    .append(position)
                    .invokeInterface(Type.class"getLong"long.classcom.facebook.presto.spi.block.Block.classint.class);
        }
        else if (parameter == double.class) {
            block.comment("%s.getDouble(block, position)"sqlType.getTypeSignature())
                    .append(constantType(callSiteBindersqlType))
                    .append(getBlockByteCode)
                    .append(position)
                    .invokeInterface(Type.class"getDouble"double.classcom.facebook.presto.spi.block.Block.classint.class);
        }
        else if (parameter == boolean.class) {
            block.comment("%s.getBoolean(block, position)"sqlType.getTypeSignature())
                    .append(constantType(callSiteBindersqlType))
                    .append(getBlockByteCode)
                    .append(position)
                    .invokeInterface(Type.class"getBoolean"boolean.classcom.facebook.presto.spi.block.Block.classint.class);
        }
        else if (parameter == Slice.class) {
            block.comment("%s.getBoolean(block, position)"sqlType.getTypeSignature())
                    .append(constantType(callSiteBindersqlType))
                    .append(getBlockByteCode)
                    .append(position)
                    .invokeInterface(Type.class"getSlice"Slice.classcom.facebook.presto.spi.block.Block.classint.class);
        }
        else {
            throw new IllegalArgumentException("Unsupported parameter type: " + parameter.getSimpleName());
        }
    }
    private static void generateAddIntermediateAsCombine(
            ClassDefinition definition,
            FieldDefinition stateField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateFactoryField,
            MethodHandle combineFunction,
            Class<?> singleStateClass,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        MethodDefinition method = declareAddIntermediate(definitiongrouped);
        Scope scope = method.getScope();
        Block body = method.getBody();
        Variable thisVariable = method.getThis();
        Variable block = scope.getVariable("block");
        Variable scratchState = scope.declareVariable(singleStateClass"scratchState");
        Variable position = scope.declareVariable(int.class"position");
        body.comment("scratchState = stateFactory.createSingleState();")
                .append(thisVariable.getField(stateFactoryField))
                .invokeInterface(AccumulatorStateFactory.class"createSingleState"Object.class)
                .checkCast(scratchState.getType())
                .putVariable(scratchState);
        if (grouped) {
            generateEnsureCapacity(scopestateFieldbody);
        }
        Block loopBody = new Block();
        if (grouped) {
            Variable groupIdsBlock = scope.getVariable("groupIdsBlock");
            loopBody.append(thisVariable.getField(stateField).invoke("setGroupId"void.classgroupIdsBlock.invoke("getGroupId"long.classposition)));
        }
        loopBody.append(thisVariable.getField(stateSerializerField).invoke("deserialize"void.classblockpositionscratchState.cast(Object.class)));
        loopBody.comment("combine(state, scratchState)")
                .append(thisVariable.getField(stateField))
                .append(scratchState)
                .append(invoke(callSiteBinder.bind(combineFunction), "combine"));
        body.append(generateBlockNonNullPositionForLoop(scopepositionloopBody))
                .ret();
    }
    private static void generateSetGroupIdFromGroupIdsBlock(Scope scopeFieldDefinition stateFieldBlock block)
    {
        Variable groupIdsBlock = scope.getVariable("groupIdsBlock");
        Variable position = scope.getVariable("position");
        ByteCodeExpression state = scope.getThis().getField(stateField);
        block.append(state.invoke("setGroupId"void.classgroupIdsBlock.invoke("getGroupId"long.classposition)));
    }
    private static void generateEnsureCapacity(Scope scopeFieldDefinition stateFieldBlock block)
    {
        Variable groupIdsBlock = scope.getVariable("groupIdsBlock");
        ByteCodeExpression state = scope.getThis().getField(stateField);
        block.append(state.invoke("ensureCapacity"void.classgroupIdsBlock.invoke("getGroupCount"long.class)));
    }
    private static MethodDefinition declareAddIntermediate(ClassDefinition definitionboolean grouped)
    {
        ImmutableList.Builder<Parameterparameters = ImmutableList.builder();
        if (grouped) {
            parameters.add(arg("groupIdsBlock"GroupByIdBlock.class));
        }
        parameters.add(arg("block"com.facebook.presto.spi.block.Block.class));
        return definition.declareMethod(
                a(),
                "addIntermediate",
                type(void.class),
                parameters.build());
    }
    private static void generateAddIntermediateAsIntermediateInput(
            ClassDefinition definition,
            FieldDefinition stateField,
            List<ParameterMetadataparameterMetadatas,
            MethodHandle intermediateInputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        MethodDefinition method = declareAddIntermediate(definitiongrouped);
        Scope scope = method.getScope();
        Block body = method.getBody();
        if (grouped) {
            generateEnsureCapacity(scopestateFieldbody);
        }
        Variable positionVariable = scope.declareVariable(int.class"position");
        Block loopBody = generateInvokeInputFunction(scopestateFieldpositionVariablenull, ImmutableList.of(scope.getVariable("block")), parameterMetadatasintermediateInputFunctioncallSiteBindergrouped);
        body.append(generateBlockNonNullPositionForLoop(scopepositionVariableloopBody))
                .ret();
    }
    // Generates a for-loop with a local variable named "position" defined, with the current position in the block,
    // loopBody will only be executed for non-null positions in the Block
    private static Block generateBlockNonNullPositionForLoop(Scope scopeVariable positionVariableBlock loopBody)
    {
        Variable rowsVariable = scope.declareVariable(int.class"rows");
        Variable blockVariable = scope.getVariable("block");
        Block block = new Block()
                .append(blockVariable)
                .invokeInterface(com.facebook.presto.spi.block.Block.class"getPositionCount"int.class)
                .putVariable(rowsVariable);
        IfStatement ifStatement = new IfStatement("if(!block.isNull(position))")
                .condition(new Block()
                        .append(blockVariable)
                        .append(positionVariable)
                        .invokeInterface(com.facebook.presto.spi.block.Block.class"isNull"boolean.classint.class))
                .ifFalse(loopBody);
        block.append(new ForLoop()
                .initialize(positionVariable.set(constantInt(0)))
                .condition(new Block()
                        .append(positionVariable)
                        .append(rowsVariable)
                        .invokeStatic(CompilerOperations.class"lessThan"boolean.classint.classint.class))
                .update(new Block().incrementVariable(positionVariable, (byte) 1))
                .body(ifStatement));
        return block;
    }
    private static void generateGroupedEvaluateIntermediate(ClassDefinition definitionFieldDefinition stateSerializerFieldFieldDefinition stateField)
    {
        Parameter groupId = arg("groupId"int.class);
        Parameter out = arg("out"BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(a(), "evaluateIntermediate"type(void.class), groupIdout);
        Variable thisVariable = method.getThis();
        ByteCodeExpression state = thisVariable.getField(stateField);
        ByteCodeExpression stateSerializer = thisVariable.getField(stateSerializerField);
        method.getBody()
                .append(state.invoke("setGroupId"void.classgroupId.cast(long.class)))
                .append(stateSerializer.invoke("serialize"void.classstate.cast(Object.class), out))
                .ret();
    }
    private static void generateEvaluateIntermediate(ClassDefinition definitionFieldDefinition stateSerializerFieldFieldDefinition stateField)
    {
        Parameter out = arg("out"BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(
                a(),
                "evaluateIntermediate",
                type(void.class),
                out);
        Variable thisVariable = method.getThis();
        ByteCodeExpression stateSerializer = thisVariable.getField(stateSerializerField);
        ByteCodeExpression state = thisVariable.getField(stateField);
        method.getBody()
                .append(stateSerializer.invoke("serialize"void.classstate.cast(Object.class), out))
                .ret();
    }
    private static void generateGroupedEvaluateFinal(
            ClassDefinition definition,
            FieldDefinition confidenceField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateField,
            @Nullable MethodHandle outputFunction,
            boolean approximate,
            CallSiteBinder callSiteBinder)
    {
        Parameter groupId = arg("groupId"int.class);
        Parameter out = arg("out"BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(a(), "evaluateFinal"type(void.class), groupIdout);
        Block body = method.getBody();
        Variable thisVariable = method.getThis();
        ByteCodeExpression state = thisVariable.getField(stateField);
        body.append(state.invoke("setGroupId"void.classgroupId.cast(long.class)));
        if (outputFunction != null) {
            body.comment("output(state, out)");
            body.append(state);
            if (approximate) {
                checkNotNull(confidenceField"confidenceField is null");
                body.append(thisVariable.getField(confidenceField));
            }
            body.append(out);
            body.append(invoke(callSiteBinder.bind(outputFunction), "output"));
        }
        else {
            checkArgument(!approximate"Approximate aggregations must specify an output function");
            ByteCodeExpression stateSerializer = thisVariable.getField(stateSerializerField);
            body.append(stateSerializer.invoke("serialize"void.classstate.cast(Object.class), out));
        }
        body.ret();
    }
    private static void generateEvaluateFinal(
            ClassDefinition definition,
            FieldDefinition confidenceField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateField,
            @Nullable
            MethodHandle outputFunction,
            boolean approximate,
            CallSiteBinder callSiteBinder)
    {
        Parameter out = arg("out"BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(
                a(),
                "evaluateFinal",
                type(void.class),
                out);
        Block body = method.getBody();
        Variable thisVariable = method.getThis();
        ByteCodeExpression state = thisVariable.getField(stateField);
        if (outputFunction != null) {
            body.comment("output(state, out)");
            body.append(state);
            if (approximate) {
                checkNotNull(confidenceField"confidenceField is null");
                body.append(thisVariable.getField(confidenceField));
            }
            body.append(out);
            body.append(invoke(callSiteBinder.bind(outputFunction), "output"));
        }
        else {
            checkArgument(!approximate"Approximate aggregations must specify an output function");
            ByteCodeExpression stateSerializer = thisVariable.getField(stateSerializerField);
            body.append(stateSerializer.invoke("serialize"void.classstate.cast(Object.class), out));
        }
        body.ret();
    }
    private static void generateConstructor(
            ClassDefinition definition,
            FieldDefinition stateSerializerField,
            FieldDefinition stateFactoryField,
            FieldDefinition inputChannelsField,
            FieldDefinition maskChannelField,
            @Nullable FieldDefinition sampleWeightChannelField,
            @Nullable FieldDefinition confidenceField,
            FieldDefinition stateField,
            boolean grouped)
    {
        Parameter stateSerializer = arg("stateSerializer"AccumulatorStateSerializer.class);
        Parameter stateFactory = arg("stateFactory"AccumulatorStateFactory.class);
        Parameter inputChannels = arg("inputChannels"type(List.classInteger.class));
        Parameter maskChannel = arg("maskChannel"type(Optional.classInteger.class));
        Parameter sampleWeightChannel = arg("sampleWeightChannel"type(Optional.classInteger.class));
        Parameter confidence = arg("confidence"double.class);
        MethodDefinition method = definition.declareConstructor(
                a(),
                stateSerializer,
                stateFactory,
                inputChannels,
                maskChannel,
                sampleWeightChannel,
                confidence);
        Block body = method.getBody();
        Variable thisVariable = method.getThis();
        body.comment("super();")
                .append(thisVariable)
                .invokeConstructor(Object.class);
        body.append(thisVariable.setField(stateSerializerFieldgenerateRequireNotNull(stateSerializer)));
        body.append(thisVariable.setField(stateFactoryFieldgenerateRequireNotNull(stateFactory)));
        body.append(thisVariable.setField(inputChannelsFieldgenerateRequireNotNull(inputChannels)));
        body.append(thisVariable.setField(maskChannelFieldgenerateRequireNotNull(maskChannel)));
        if (sampleWeightChannelField != null) {
            body.append(thisVariable.setField(sampleWeightChannelFieldgenerateRequireNotNull(sampleWeightChannel)));
        }
        String createState;
        if (grouped) {
            createState = "createGroupedState";
        }
        else {
            createState = "createSingleState";
        }
        if (confidenceField != null) {
            body.append(thisVariable.setField(confidenceFieldconfidence));
        }
        body.append(thisVariable.setField(stateFieldstateFactory.invoke(createStateObject.class).cast(stateField.getType())));
        body.ret();
    }
    private static ByteCodeExpression generateRequireNotNull(Variable variable)
    {
        return invokeStatic(Objects.class"requireNonNull"Object.classvariable.cast(Object.class), constantString(variable.getName() + " is null"))
                .cast(variable.getType());
    }
New to GrepCode? Check out our FAQ X