Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.gen;
 
 
 import java.util.List;
 
 import static com.facebook.presto.byteCode.Access.FINAL;
 import static com.facebook.presto.byteCode.Access.PRIVATE;
 import static com.facebook.presto.byteCode.Access.PUBLIC;
 import static com.facebook.presto.byteCode.Access.a;
 import static com.facebook.presto.byteCode.NamedParameterDefinition.arg;
 import static com.facebook.presto.byteCode.ParameterizedType.type;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantInt;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantLong;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.newInstance;
 import static com.facebook.presto.sql.gen.Bootstrap.BOOTSTRAP_METHOD;
 import static com.facebook.presto.sql.gen.CompilerUtils.defineClass;
 import static com.facebook.presto.sql.gen.CompilerUtils.makeClassName;
 import static com.facebook.presto.sql.gen.SqlTypeByteCodeExpression.constantType;
 
 public class JoinProbeCompiler
 {
             {
                 @Override
                 public HashJoinOperatorFactoryFactory load(JoinOperatorCacheKey key)
                         throws Exception
                 {
                     return internalCompileJoinOperatorFactory(key.getTypes(), key.getProbeChannels(), key.getProbeHashChannel());
                 }
             });
 
     /**
      * Returns an {@link OperatorFactory} for a lookup join over the given probe-side layout.
      * <p>
      * A {@code HashJoinOperatorFactoryFactory} is fetched with {@code .get(...)} using a
      * {@code JoinOperatorCacheKey} built from (probe types, probe join channels, probe hash channel,
      * outer-join flag). It is then asked to build an operator factory bound to {@code operatorId}
      * and {@code lookupSourceSupplier}. Wrapper exceptions raised by the lookup are unwrapped, and
      * their cause is rethrown via {@code Throwables.propagate}.
      * <p>
      * NOTE(review): the receiver of {@code .get(...)} below, the argument separators, and the closing
      * angle brackets in the signature appear to have been lost when this source was extracted.
      * Presumably the receiver is the loading cache whose {@code CacheLoader} (just above) calls
      * {@code internalCompileJoinOperatorFactory} — confirm against upstream.
      */
     public OperatorFactory compileJoinOperatorFactory(int operatorId,
             LookupSourceSupplier lookupSourceSupplier,
             List<? extends TypeprobeTypes,
             List<IntegerprobeJoinChannel,
             Optional<IntegerprobeHashChannel,
             boolean enableOuterJoin)
     {
         try {
             // Fetch (presumably compiling on first use) the factory-factory for this probe layout.
             HashJoinOperatorFactoryFactory operatorFactoryFactory = .get(new JoinOperatorCacheKey(probeTypesprobeJoinChannelprobeHashChannelenableOuterJoin));
             return operatorFactoryFactory.createHashJoinOperatorFactory(operatorIdlookupSourceSupplierprobeTypesprobeJoinChannelenableOuterJoin);
         }
         catch (ExecutionException | UncheckedExecutionException | ExecutionError e) {
             // These are wrapper exceptions; surface the underlying cause to the caller instead.
             throw Throwables.propagate(e.getCause());
         }
     }
 
     /**
      * Compiles the per-layout classes needed for a lookup join and bundles them into a
      * {@code HashJoinOperatorFactoryFactory}.
      * <ol>
      * <li>Compile a {@link JoinProbe} class for {@code types} / {@code probeJoinChannel} /
      *     {@code probeHashChannel}.</li>
      * <li>Generate a {@link JoinProbeFactory} class with a default constructor. Its
      *     {@code createJoinProbe(lookupSource, page)} returns
      *     {@code new <probe class>(lookupSource, page)}. Define that class in a new
      *     {@code DynamicClassLoader} created from the probe class's loader, then instantiate it
      *     reflectively.</li>
      * <li>Obtain an {@link OperatorFactory} class via {@code IsolatedClass.isolateClass} for
      *     {@code LookupJoinOperatorFactory} / {@code LookupJoinOperator} in that same loader. This
      *     presumably loads per-layout copies of those classes — confirm against {@code IsolatedClass}.</li>
      * </ol>
      * NOTE(review): the parameter separators, the closing angle brackets of the generic types, and the
      * arguments to every {@code a()} access-flag call appear to have been lost in extraction. The
      * statically imported PUBLIC/PRIVATE/FINAL are otherwise unused. The generated class,
      * its constructor and {@code createJoinProbe} presumably need PUBLIC for {@code newInstance()}
      * and interface dispatch to succeed — confirm against upstream.
      */
     public HashJoinOperatorFactoryFactory internalCompileJoinOperatorFactory(List<TypetypesList<IntegerprobeJoinChannelOptional<IntegerprobeHashChannel)
     {
         Class<? extends JoinProbejoinProbeClass = compileJoinProbe(typesprobeJoinChannelprobeHashChannel);
 
        // Generated class: implements JoinProbeFactory with a single createJoinProbe method.
        ClassDefinition classDefinition = new ClassDefinition(new CompilerContext(),
                a(),
                makeClassName("JoinProbeFactory"),
                type(Object.class),
                type(JoinProbeFactory.class));
        classDefinition.declareDefaultConstructor(a());
        // Generated body: return new <joinProbeClass>(lookupSource, page);
        classDefinition.declareMethod(new CompilerContext(),
                a(),
                "createJoinProbe",
                type(JoinProbe.class),
                arg("lookupSource"LookupSource.class),
                arg("page"Page.class))
                .getBody()
                .newObject(joinProbeClass)
                .dup()
                .getVariable("lookupSource")
                .getVariable("page")
                .invokeConstructor(joinProbeClassLookupSource.classPage.class)
                .retObject();
        // The generated factory must be able to see joinProbeClass, so its loader is built from that class's loader.
        DynamicClassLoader classLoader = new DynamicClassLoader(joinProbeClass.getClassLoader());
        Class<? extends JoinProbeFactoryjoinProbeFactoryClass = defineClass(classDefinitionJoinProbeFactory.classclassLoader);
        JoinProbeFactory joinProbeFactory;
        try {
            joinProbeFactory = joinProbeFactoryClass.newInstance();
        }
        catch (Exception e) {
            // Any reflective failure is rethrown unchecked.
            throw Throwables.propagate(e);
        }
        Class<? extends OperatorFactoryoperatorFactoryClass = IsolatedClass.isolateClass(
                classLoader,
                OperatorFactory.class,
                LookupJoinOperatorFactory.class,
                LookupJoinOperator.class);
        return new HashJoinOperatorFactoryFactory(joinProbeFactoryoperatorFactoryClass);
    }
    public JoinProbeFactory internalCompileJoinProbe(List<TypetypesList<IntegerprobeChannelsOptional<IntegerprobeHashChannel)
    {
        return new ReflectionJoinProbeFactory(compileJoinProbe(typesprobeChannelsprobeHashChannel));
    }
    /**
     * Generates and defines a {@link JoinProbe} implementation specialized for one probe layout.
     * <p>
     * The generated class declares the following fields:
     * <ul>
     * <li>the lookup source;</li>
     * <li>the page's position count;</li>
     * <li>one Block field per input channel ({@code block_<i>});</li>
     * <li>one Block field per join channel ({@code probeBlock_<i>});</li>
     * <li>a Block array and a {@code Page} built from the join-channel blocks;</li>
     * <li>a hash Block;</li>
     * <li>the current position.</li>
     * </ul>
     * The constructor and methods are emitted by the {@code generate*} helpers. The class is defined
     * with the call-site bindings collected in {@code callSiteBinder} and this class's loader.
     * <p>
     * NOTE(review): the {@code a()} access-flag arguments, argument separators and closing angle
     * brackets appear to have been lost in extraction. {@code position} is reassigned by the generated
     * {@code advanceNextPosition}, so unlike the other fields it cannot be FINAL — confirm the flags
     * upstream.
     */
    private Class<? extends JoinProbecompileJoinProbe(List<TypetypesList<IntegerprobeChannelsOptional<IntegerprobeHashChannel)
    {
        // Collects constants (e.g. Type instances) referenced by the generated bytecode.
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        ClassDefinition classDefinition = new ClassDefinition(new CompilerContext(),
                a(),
                makeClassName("JoinProbe"),
                type(Object.class),
                type(JoinProbe.class));
        // declare fields; all are initialized by the generated constructor
        FieldDefinition lookupSourceField = classDefinition.declareField(a(), "lookupSource"LookupSource.class);
        FieldDefinition positionCountField = classDefinition.declareField(a(), "positionCount"int.class);
        // One field per input channel, filled from page.getBlock(i).
        List<FieldDefinitionblockFields = new ArrayList<>();
        for (int i = 0; i < types.size(); i++) {
            FieldDefinition channelField = classDefinition.declareField(a(), "block_" + icom.facebook.presto.spi.block.Block.class);
            blockFields.add(channelField);
        }
        // One field per join channel; each aliases the matching block_ field.
        List<FieldDefinitionprobeBlockFields = new ArrayList<>();
        for (int i = 0; i < probeChannels.size(); i++) {
            FieldDefinition channelField = classDefinition.declareField(a(), "probeBlock_" + icom.facebook.presto.spi.block.Block.class);
            probeBlockFields.add(channelField);
        }
        FieldDefinition probeBlocksArrayField = classDefinition.declareField(a(), "probeBlocks"com.facebook.presto.spi.block.Block[].class);
        FieldDefinition probePageField = classDefinition.declareField(a(), "probePage"Page.class);
        FieldDefinition positionField = classDefinition.declareField(a(), "position"int.class);
        // Only assigned by the constructor when probeHashChannel is present.
        FieldDefinition probeHashBlockField = classDefinition.declareField(a(), "probeHashBlock"com.facebook.presto.spi.block.Block.class);
        // Emit the constructor and the JoinProbe methods.
        generateConstructor(classDefinitionprobeChannelsprobeHashChannellookupSourceFieldblockFieldsprobeBlockFieldsprobeBlocksArrayFieldprobePageFieldprobeHashBlockFieldpositionFieldpositionCountField);
        generateGetChannelCountMethod(classDefinitionblockFields.size());
        generateAppendToMethod(classDefinitioncallSiteBindertypesblockFieldspositionField);
        generateAdvanceNextPosition(classDefinitionpositionFieldpositionCountField);
        generateGetCurrentJoinPosition(classDefinitioncallSiteBinderlookupSourceFieldprobePageFieldprobeHashChannelprobeHashBlockFieldpositionField);
        generateCurrentRowContainsNull(classDefinitionprobeBlockFieldspositionField);
        return defineClass(classDefinitionJoinProbe.classcallSiteBinder.getBindings(), getClass().getClassLoader());
    }
    /**
     * Emits the generated probe class's {@code (LookupSource lookupSource, Page page)} constructor.
     * The constructor:
     * <ul>
     * <li>calls {@code Object()};</li>
     * <li>stores {@code lookupSource} and {@code page.getPositionCount()};</li>
     * <li>copies {@code page.getBlock(i)} into each {@code blockFields[i]};</li>
     * <li>points each probe-channel field at {@code blockFields[probeChannels[i]]};</li>
     * <li>builds {@code probeBlocks} (one entry per probe channel) and
     *     {@code probePage = new Page(probeBlocks)};</li>
     * <li>when {@code probeHashChannel} is present, stores {@code blockFields[probeHashChannel]} as the
     *     hash block;</li>
     * <li>sets {@code position = -1} so the first {@code advanceNextPosition()} moves to position 0.</li>
     * </ul>
     * NOTE(review): {@code ReflectionJoinProbeFactory} looks this constructor up with
     * {@code getConstructor}, which only finds public constructors. The empty {@code a()} below is
     * therefore presumably a PUBLIC flag lost in extraction, as are the argument separators — confirm.
     */
    private void generateConstructor(ClassDefinition classDefinition,
            List<IntegerprobeChannels,
            Optional<IntegerprobeHashChannel,
            FieldDefinition lookupSourceField,
            List<FieldDefinitionblockFields,
            List<FieldDefinitionprobeChannelFields,
            FieldDefinition probeBlocksArrayField,
            FieldDefinition probePageField,
            FieldDefinition probeHashBlockField,
            FieldDefinition positionField,
            FieldDefinition positionCountField)
    {
        CompilerContext context = new CompilerContext();
        Block constructor = classDefinition.declareConstructor(context,
                a(),
                arg("lookupSource"LookupSource.class),
                arg("page"Page.class))
                .getBody()
                .comment("super();")
                .pushThis()
                .invokeConstructor(Object.class);
        constructor.comment("this.lookupSource = lookupSource;")
                .append(context.getVariable("this").setField(lookupSourceFieldcontext.getVariable("lookupSource")));
        constructor.comment("this.positionCount = page.getPositionCount();")
                .append(context.getVariable("this").setField(positionCountFieldcontext.getVariable("page").invoke("getPositionCount"int.class)));
        // Unrolled at generation time: this.block_<index> = page.getBlock(<index>)
        constructor.comment("Set block fields");
        for (int index = 0; index < blockFields.size(); index++) {
            constructor.append(context.getVariable("this").setField(
                    blockFields.get(index),
                    context.getVariable("page").invoke("getBlock"com.facebook.presto.spi.block.Block.classconstantInt(index))));
        }
        // Unrolled: this.probeBlock_<index> = this.block_<probeChannels[index]>
        constructor.comment("Set probe channel fields");
        for (int index = 0; index < probeChannelFields.size(); index++) {
            constructor.append(context.getVariable("this").setField(
                    probeChannelFields.get(index),
                    context.getVariable("this").getField(blockFields.get(probeChannels.get(index)))));
        }
        constructor.comment("this.probeBlocks = new Block[<probeChannelCount>];");
        constructor
                .pushThis()
                .push(probeChannelFields.size())
                .newArray(com.facebook.presto.spi.block.Block.class)
                .putField(probeBlocksArrayField);
        // Unrolled: this.probeBlocks[<index>] = this.probeBlock_<index>
        for (int index = 0; index < probeChannelFields.size(); index++) {
            constructor
                    .pushThis()
                    .getField(probeBlocksArrayField)
                    .push(index)
                    .pushThis()
                    .getField(probeChannelFields.get(index))
                    .putObjectArrayElement();
        }
        // probePage contains only the join-key blocks, in probeChannels order.
        ByteCodeExpression page = newInstance(Page.classcontext.getVariable("this").getField(probeBlocksArrayField));
        constructor.comment("this.probePage = new Page(probeBlocks)")
                .append(context.getVariable("this").setField(probePageFieldpage));
        if (probeHashChannel.isPresent()) {
            Integer index = probeHashChannel.get();
            constructor.comment("this.probeHashBlock = blocks[hashChannel.get()]")
                    .append(context.getVariable("this").setField(
                            probeHashBlockField,
                            context.getVariable("this").getField(blockFields.get(index))));
        }
        // Start before the first row; advanceNextPosition() increments before comparing.
        constructor.comment("this.position = -1;")
                .append(context.getVariable("this").setField(positionFieldconstantInt(-1)));
        constructor.ret();
    }
    /**
     * Emits {@code int getChannelCount()}, which returns the constant {@code channelCount}.
     * The caller passes the number of per-channel block fields.
     * <p>
     * NOTE(review): the argument separator in the signature and the argument to {@code a()} appear
     * to have been lost in extraction. Since the generated class implements {@link JoinProbe}, the
     * flag is presumably PUBLIC — confirm.
     */
    private void generateGetChannelCountMethod(ClassDefinition classDefinitionint channelCount)
    {
        classDefinition.declareMethod(new CompilerContext(),
                a(),
                "getChannelCount",
                type(int.class))
                .getBody()
                .push(channelCount)
                .retInt();
    }
    /**
     * Emits {@code void appendTo(PageBuilder pageBuilder)}.
     * <p>
     * For every channel {@code i}, the generated method calls
     * {@code types[i].appendTo(block_i, position, pageBuilder.getBlockBuilder(i))}. The loop runs at
     * generation time, so the emitted body contains one straight-line call per channel. Output
     * channel {@code i} of the page builder receives input channel {@code i}.
     * <p>
     * NOTE(review): the argument to {@code a()} and several argument separators appear to have been
     * lost in extraction — confirm against upstream.
     */
    private void generateAppendToMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<TypetypesList<FieldDefinitionblockFields,
            FieldDefinition positionField)
    {
        CompilerContext context = new CompilerContext();
        Block appendToBody = classDefinition.declareMethod(context,
                a(),
                "appendTo",
                type(void.class),
                arg("pageBuilder"PageBuilder.class))
                .getBody();
        for (int index = 0; index < blockFields.size(); index++) {
            Type type = types.get(index);
            // The Type instance is referenced via callSiteBinder rather than through a field of the generated class.
            appendToBody
                    .comment("%s.appendTo(block_%s, position, pageBuilder.getBlockBuilder(%s));"type.getClass(), indexindex)
                    .append(constantType(contextcallSiteBindertype).invoke("appendTo"void.class,
                            context.getVariable("this").getField(blockFields.get(index)),
                            context.getVariable("this").getField(positionField),
                            context.getVariable("pageBuilder").invoke("getBlockBuilder"BlockBuilder.classconstantInt(index))));
        }
        appendToBody.ret();
    }
    /**
     * Emits {@code boolean advanceNextPosition()}. The generated method increments {@code position}
     * and then returns {@code position < positionCount}.
     * <p>
     * {@code position} is incremented on every call, including calls made after the end of the page
     * has been reached.
     * <p>
     * NOTE(review): the parameter separators and the argument to {@code a()} appear to have been lost
     * in extraction — confirm.
     */
    private void generateAdvanceNextPosition(ClassDefinition classDefinitionFieldDefinition positionFieldFieldDefinition positionCountField)
    {
        CompilerContext context = new CompilerContext();
        Block advanceNextPositionBody = classDefinition.declareMethod(context,
                a(),
                "advanceNextPosition",
                type(boolean.class))
                .getBody();
        advanceNextPositionBody
                .comment("this.position = this.position + 1;")
                .pushThis()
                .pushThis()
                .getField(positionField)
                .push(1)
                .intAdd()
                .putField(positionField);
        // Turn the int comparison into a boolean result with a conditional jump:
        // push true on the "less than" path and false on the fall-through path.
        LabelNode lessThan = new LabelNode("lessThan");
        LabelNode end = new LabelNode("end");
        advanceNextPositionBody
                .comment("return position < positionCount;")
                .pushThis()
                .getField(positionField)
                .pushThis()
                .getField(positionCountField)
                .append(JumpInstruction.jumpIfIntLessThan(lessThan))
                .push(false)
                .gotoLabel(end)
                .visitLabel(lessThan)
                .push(true)
                .visitLabel(end)
                .retBoolean();
    }
    /**
     * Emits {@code long getCurrentJoinPosition()}.
     * <ul>
     * <li>If {@code currentRowContainsNull()} is true, the generated method returns -1.</li>
     * <li>Otherwise, without a hash channel, it returns
     *     {@code lookupSource.getJoinPosition(position, probePage)}.</li>
     * <li>Otherwise, with a hash channel, it returns
     *     {@code lookupSource.getJoinPosition(position, probePage, (int) <type>.getLong(probeHashBlock, position))}.</li>
     * </ul>
     * The {@code .cast(int.class)} narrows the {@code long} read by {@code getLong} to 32 bits before
     * it is passed to the lookup source.
     * <p>
     * NOTE(review): the type argument to {@code constantType} in the hashed branch and several
     * argument separators appear to have been lost in extraction. Presumably the type is a
     * bigint/long type — confirm against upstream.
     */
    private void generateGetCurrentJoinPosition(ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            FieldDefinition lookupSourceField,
            FieldDefinition probePageField,
            Optional<IntegerprobeHashChannel,
            FieldDefinition probeHashBlockField,
            FieldDefinition positionField)
    {
        CompilerContext context = new CompilerContext();
// NOTE(review): leftover commented-out code below; it has no effect and can be deleted.
//        Variable thisVariable = context.getVariable("this");
        // Rows with a null join key never match: return -1 early.
        Block body = classDefinition.declareMethod(context,
                a(),
                "getCurrentJoinPosition",
                type(long.class))
                .getBody()
                .append(new IfStatement(
                        context,
                        context.getVariable("this").invoke("currentRowContainsNull"boolean.class),
                        constantLong(-1).ret(),
                        null
                ));
        ByteCodeExpression position = context.getVariable("this").getField(positionField);
        ByteCodeExpression page = context.getVariable("this").getField(probePageField);
        ByteCodeExpression probeHashBlock = context.getVariable("this").getField(probeHashBlockField);
        if (probeHashChannel.isPresent()) {
            body.append(context.getVariable("this").getField(lookupSourceField).invoke("getJoinPosition"long.class,
                    position,
                    page,
                    constantType(contextcallSiteBinder.).invoke("getLong",
                            long.class,
                            probeHashBlock,
                            position)
                            .cast(int.class)))
                    .retLong();
        }
        else {
            body.append(context.getVariable("this").getField(lookupSourceField).invoke("getJoinPosition"long.classpositionpage)).retLong();
        }
    }
    /**
     * Emits {@code boolean currentRowContainsNull()}. The generated method returns true as soon as any
     * probe (join-key) block is null at {@code position}, and false otherwise. The check is unrolled at
     * generation time: one {@code isNull} test and label per probe block.
     * <p>
     * NOTE(review): the early return uses {@code retBoolean()} but the final return uses
     * {@code retInt()}. Both presumably emit the same int return instruction, because booleans are
     * ints on the JVM operand stack, but {@code retBoolean()} would be the consistent choice. The
     * parameter separators and the argument to {@code a()} also appear to have been lost in extraction.
     */
    private void generateCurrentRowContainsNull(ClassDefinition classDefinitionList<FieldDefinitionprobeBlockFieldsFieldDefinition positionField)
    {
        CompilerContext context = new CompilerContext();
        Block body = classDefinition.declareMethod(context,
                a(),
                "currentRowContainsNull",
                type(boolean.class))
                .getBody();
        for (FieldDefinition probeBlockField : probeBlockFields) {
            // if (this.<probeBlockField>.isNull(this.position)) return true;
            LabelNode checkNextField = new LabelNode("checkNextField");
            body
                    .append(context.getVariable("this").getField(probeBlockField).invoke("isNull"boolean.classcontext.getVariable("this").getField(positionField)))
                    .ifFalseGoto(checkNextField)
                    .push(true)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }
        body.push(false).retInt();
    }
    public static class ReflectionJoinProbeFactory
            implements JoinProbeFactory
    {
        private final Constructor<? extends JoinProbeconstructor;
        public ReflectionJoinProbeFactory(Class<? extends JoinProbejoinProbeClass)
        {
            try {
                 = joinProbeClass.getConstructor(LookupSource.classPage.class);
            }
            catch (NoSuchMethodException e) {
                throw Throwables.propagate(e);
            }
        }
        @Override
        public JoinProbe createJoinProbe(LookupSource lookupSourcePage page)
        {
            try {
                return .newInstance(lookupSourcepage);
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }
    private static final class JoinOperatorCacheKey
    {
        private final List<Typetypes;
        private final List<IntegerprobeChannels;
        private final boolean enableOuterJoin;
        private final Optional<IntegerprobeHashChannel;
        private JoinOperatorCacheKey(List<? extends Typetypes,
                List<IntegerprobeChannels,
                Optional<IntegerprobeHashChannel,
                boolean enableOuterJoin)
        {
            this. = probeHashChannel;
            this. = ImmutableList.copyOf(types);
            this. = ImmutableList.copyOf(probeChannels);
            this. = enableOuterJoin;
        }
        private List<TypegetTypes()
        {
            return ;
        }
        private List<IntegergetProbeChannels()
        {
            return ;
        }
        private Optional<IntegergetProbeHashChannel()
        {
            return ;
        }
        @Override
        public int hashCode()
        {
            return Objects.hashCode();
        }
        @Override
        public boolean equals(Object obj)
        {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof JoinOperatorCacheKey)) {
                return false;
            }
            JoinOperatorCacheKey other = (JoinOperatorCacheKeyobj;
            return Objects.equal(this.other.types) &&
                    Objects.equal(this.other.probeChannels) &&
                    Objects.equal(this.other.probeHashChannel) &&
                    Objects.equal(this.other.enableOuterJoin);
        }
    }
    private static class HashJoinOperatorFactoryFactory
    {
        private final JoinProbeFactory joinProbeFactory;
        private final Constructor<? extends OperatorFactoryconstructor;
        private HashJoinOperatorFactoryFactory(JoinProbeFactory joinProbeFactoryClass<? extends OperatorFactoryoperatorFactoryClass)
        {
            this. = joinProbeFactory;
            try {
                 = operatorFactoryClass.getConstructor(int.classLookupSourceSupplier.classList.classboolean.classJoinProbeFactory.class);
            }
            catch (NoSuchMethodException e) {
                throw Throwables.propagate(e);
            }
        }
                int operatorId,
                LookupSourceSupplier lookupSourceSupplier,
                List<? extends TypeprobeTypes,
                List<IntegerprobeJoinChannel,
                boolean enableOuterJoin)
        {
            try {
                return .newInstance(operatorIdlookupSourceSupplierprobeTypesenableOuterJoin);
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }
    public static void checkState(boolean leftboolean right)
    {
        if (left != right) {
            throw new IllegalStateException();
        }
    }
New to GrepCode? Check out our FAQ X