Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.gen;
 
 
 import com.facebook.presto.byteCode.OpCode;
 import com.facebook.presto.byteCode.Parameter;
 import com.facebook.presto.spi.type.BigintType;
 import com.google.common.base.Throwables;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 import com.google.common.collect.ImmutableList;
 import com.google.common.util.concurrent.ExecutionError;
 import com.google.common.util.concurrent.UncheckedExecutionException;

 import java.lang.reflect.Constructor;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.ExecutionException;

 import static com.facebook.presto.byteCode.Access.FINAL;
 import static com.facebook.presto.byteCode.Access.PRIVATE;
 import static com.facebook.presto.byteCode.Access.PUBLIC;
 import static com.facebook.presto.byteCode.Access.a;
 import static com.facebook.presto.byteCode.Parameter.arg;
 import static com.facebook.presto.byteCode.ParameterizedType.type;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantInt;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantNull;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.notEqual;
 import static com.facebook.presto.sql.gen.CompilerUtils.defineClass;
 import static com.facebook.presto.sql.gen.CompilerUtils.makeClassName;
 import static com.facebook.presto.sql.gen.SqlTypeByteCodeExpression.constantType;
 import static com.google.common.base.Preconditions.checkNotNull;
 
 public class JoinCompiler
 {
     private final LoadingCache<CacheKeyLookupSourceFactorylookupSourceFactories = CacheBuilder.newBuilder().maximumSize(1000).build(
             new CacheLoader<CacheKeyLookupSourceFactory>()
             {
                 @Override
                 public LookupSourceFactory load(CacheKey key)
                         throws Exception
                 {
                     return internalCompileLookupSourceFactory(key.getTypes(), key.getJoinChannels());
                 }
             });
 
     private final LoadingCache<CacheKeyClass<? extends PagesHashStrategy>> hashStrategies = CacheBuilder.newBuilder().maximumSize(1000).build(
             new CacheLoader<CacheKeyClass<? extends PagesHashStrategy>>() {
                 @Override
                 public Class<? extends PagesHashStrategyload(CacheKey key)
                         throws Exception
                 {
                     return internalCompileHashStrategy(key.getTypes(), key.getJoinChannels());
                 }
             });
 
     public LookupSourceFactory compileLookupSourceFactory(List<? extends TypetypesList<IntegerjoinChannels)
     {
         try {
             return .get(new CacheKey(typesjoinChannels));
         }
         catch (ExecutionException | UncheckedExecutionException | ExecutionError e) {
             throw Throwables.propagate(e.getCause());
         }
     }
 
    {
        checkNotNull(types"types is null");
        checkNotNull(joinChannels"joinChannels is null");
        try {
            return new PagesHashStrategyFactory(.get(new CacheKey(typesjoinChannels)));
        }
            throw Throwables.propagate(e.getCause());
        }
    }
    {
        Class<? extends PagesHashStrategypagesHashStrategyClass = internalCompileHashStrategy(typesjoinChannels);
        Class<? extends LookupSourcelookupSourceClass = IsolatedClass.isolateClass(
                new DynamicClassLoader(getClass().getClassLoader()),
                LookupSource.class,
                InMemoryJoinHash.class,
                InMemoryJoinHash.UnvisitedJoinPositionIterator.class);
        return new LookupSourceFactory(lookupSourceClassnew PagesHashStrategyFactory(pagesHashStrategyClass));
    }
    private Class<? extends PagesHashStrategyinternalCompileHashStrategy(List<TypetypesList<IntegerjoinChannels)
    {
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        ClassDefinition classDefinition = new ClassDefinition(
                a(),
                makeClassName("PagesHashStrategy"),
                type(Object.class),
                type(PagesHashStrategy.class));
        List<FieldDefinitionchannelFields = new ArrayList<>();
        for (int i = 0; i < types.size(); i++) {
            FieldDefinition channelField = classDefinition.declareField(a(), "channel_" + itype(List.classcom.facebook.presto.spi.block.Block.class));
            channelFields.add(channelField);
        }
        List<TypejoinChannelTypes = new ArrayList<>();
        List<FieldDefinitionjoinChannelFields = new ArrayList<>();
        for (int i = 0; i < joinChannels.size(); i++) {
            joinChannelTypes.add(types.get(joinChannels.get(i)));
            FieldDefinition channelField = classDefinition.declareField(a(), "joinChannel_" + itype(List.classcom.facebook.presto.spi.block.Block.class));
            joinChannelFields.add(channelField);
        }
        FieldDefinition hashChannelField = classDefinition.declareField(a(), "hashChannel"type(List.classcom.facebook.presto.spi.block.Block.class));
        generateConstructor(classDefinitionjoinChannelschannelFieldsjoinChannelFieldshashChannelField);
        generateGetChannelCountMethod(classDefinitionchannelFields);
        generateAppendToMethod(classDefinitioncallSiteBindertypeschannelFields);
        generateHashPositionMethod(classDefinitioncallSiteBinderjoinChannelTypesjoinChannelFieldshashChannelField);
        generateHashRowMethod(classDefinitioncallSiteBinderjoinChannelTypes);
        generateRowEqualsRowMethod(classDefinitioncallSiteBinderjoinChannelTypes);
        generatePositionEqualsRowMethod(classDefinitioncallSiteBinderjoinChannelTypesjoinChannelFields);
        generatePositionEqualsPositionMethod(classDefinitioncallSiteBinderjoinChannelTypesjoinChannelFields);
        return defineClass(classDefinitionPagesHashStrategy.classcallSiteBinder.getBindings(), getClass().getClassLoader());
    }
    private void generateConstructor(ClassDefinition classDefinition,
            List<IntegerjoinChannels,
            List<FieldDefinitionchannelFields,
            List<FieldDefinitionjoinChannelFields,
            FieldDefinition hashChannelField)
    {
        Parameter channels = arg("channels"type(List.classtype(List.classcom.facebook.presto.spi.block.Block.class)));
        Parameter hashChannel = arg("hashChannel"type(Optional.classInteger.class));
        MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(), channelshashChannel);
        Variable thisVariable = constructorDefinition.getThis();
        Block constructor = constructorDefinition
                .getBody()
                .comment("super();")
                .append(thisVariable)
                .invokeConstructor(Object.class);
        constructor.comment("Set channel fields");
        for (int index = 0; index < channelFields.size(); index++) {
            ByteCodeExpression channel = channels.invoke("get"Object.classconstantInt(index))
                    .cast(type(List.classcom.facebook.presto.spi.block.Block.class));
            constructor.append(thisVariable.setField(channelFields.get(index), channel));
        }
        constructor.comment("Set join channel fields");
        for (int index = 0; index < joinChannelFields.size(); index++) {
            ByteCodeExpression joinChannel = channels.invoke("get"Object.classconstantInt(joinChannels.get(index)))
                    .cast(type(List.classcom.facebook.presto.spi.block.Block.class));
            constructor.append(thisVariable.setField(joinChannelFields.get(index), joinChannel));
        }
        constructor.comment("Set hashChannel");
        constructor.append(new IfStatement()
                .condition(hashChannel.invoke("isPresent"boolean.class))
                .ifTrue(thisVariable.setField(
                        hashChannelField,
                        channels.invoke("get"Object.classhashChannel.invoke("get"Object.class).cast(Integer.class).cast(int.class))))
                .ifFalse(thisVariable.setField(
                        hashChannelField,
                        constantNull(hashChannelField.getType()))));
        constructor.ret();
    }
    private void generateGetChannelCountMethod(ClassDefinition classDefinitionList<FieldDefinitionchannelFields)
    {
        classDefinition.declareMethod(
                a(),
                "getChannelCount",
                type(int.class))
                .getBody()
                .push(channelFields.size())
                .retInt();
    }
    private void generateAppendToMethod(ClassDefinition classDefinitionCallSiteBinder callSiteBinderList<TypetypesList<FieldDefinitionchannelFields)
    {
        Parameter blockIndex = arg("blockIndex"int.class);
        Parameter blockPosition = arg("blockPosition"int.class);
        Parameter pageBuilder = arg("pageBuilder"PageBuilder.class);
        Parameter outputChannelOffset = arg("outputChannelOffset"int.class);
        MethodDefinition method = classDefinition.declareMethod(a(), "appendTo"type(void.class), blockIndexblockPositionpageBuilderoutputChannelOffset);
        Variable thisVariable = method.getThis();
        Block appendToBody = method.getBody();
        for (int index = 0; index < channelFields.size(); index++) {
            Type type = types.get(index);
            ByteCodeExpression typeExpression = constantType(callSiteBindertype);
            ByteCodeExpression block = thisVariable
                    .getField(channelFields.get(index))
                    .invoke("get"Object.classblockIndex)
                    .cast(com.facebook.presto.spi.block.Block.class);
            appendToBody
                    .comment("%s.appendTo(channel_%s.get(blockIndex), blockPosition, pageBuilder.getBlockBuilder(outputChannelOffset + %s));"type.getClass(), indexindex)
                    .append(typeExpression)
                    .append(block)
                    .append(blockPosition)
                    .append(pageBuilder)
                    .append(outputChannelOffset)
                    .push(index)
                    .append(.)
                    .invokeVirtual(PageBuilder.class"getBlockBuilder"BlockBuilder.classint.class)
                    .invokeInterface(Type.class"appendTo"void.classcom.facebook.presto.spi.block.Block.classint.classBlockBuilder.class);
        }
        appendToBody.ret();
    }
    private void generateHashPositionMethod(ClassDefinition classDefinitionCallSiteBinder callSiteBinderList<TypejoinChannelTypesList<FieldDefinitionjoinChannelFieldsFieldDefinition hashChannelField)
    {
        Parameter blockIndex = arg("blockIndex"int.class);
        Parameter blockPosition = arg("blockPosition"int.class);
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(
                a(),
                "hashPosition",
                type(int.class),
                blockIndex,
                blockPosition);
        Variable thisVariable = hashPositionMethod.getThis();
        ByteCodeExpression hashChannel = thisVariable.getField(hashChannelField);
        ByteCodeExpression bigintType = constantType(callSiteBinder.);
        IfStatement ifStatement = new IfStatement();
        ifStatement.condition(notEqual(hashChannelconstantNull(hashChannelField.getType())));
        ifStatement.ifTrue(
                bigintType.invoke(
                        "getLong",
                        long.class,
                        hashChannel.invoke("get"Object.classblockIndex).cast(com.facebook.presto.spi.block.Block.class),
                        blockPosition)
                        .cast(int.class)
                        .ret()
        );
        hashPositionMethod
                .getBody()
                .append(ifStatement);
        Variable resultVariable = hashPositionMethod.getScope().declareVariable(int.class"result");
        hashPositionMethod.getBody().push(0).putVariable(resultVariable);
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(callSiteBinderjoinChannelTypes.get(index));
            ByteCodeExpression block = hashPositionMethod
                    .getThis()
                    .getField(joinChannelFields.get(index))
                    .invoke("get"Object.classblockIndex)
                    .cast(com.facebook.presto.spi.block.Block.class);
            hashPositionMethod
                    .getBody()
                    .getVariable(resultVariable)
                    .push(31)
                    .append(.)
                    .append(typeHashCode(typeblockblockPosition))
                    .append(.)
                    .putVariable(resultVariable);
        }
        hashPositionMethod
                .getBody()
                .getVariable(resultVariable)
                .retInt();
    }
    private void generateHashRowMethod(ClassDefinition classDefinitionCallSiteBinder callSiteBinderList<TypejoinChannelTypes)
    {
        Parameter position = arg("position"int.class);
        Parameter blocks = arg("blocks"com.facebook.presto.spi.block.Block[].class);
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(a(), "hashRow"type(int.class), positionblocks);
        Variable resultVariable = hashPositionMethod.getScope().declareVariable(int.class"result");
        hashPositionMethod.getBody().push(0).putVariable(resultVariable);
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(callSiteBinderjoinChannelTypes.get(index));
            // todo is the case needed
            ByteCodeExpression block = blocks.getElement(index).cast(com.facebook.presto.spi.block.Block.class);
            hashPositionMethod
                    .getBody()
                    .getVariable(resultVariable)
                    .push(31)
                    .append(.)
                    .append(typeHashCode(typeblockposition))
                    .append(.)
                    .putVariable(resultVariable);
        }
        hashPositionMethod
                .getBody()
                .getVariable(resultVariable)
                .retInt();
    }
    private static ByteCodeNode typeHashCode(ByteCodeExpression typeByteCodeExpression blockRefByteCodeExpression blockPosition)
    {
        return new IfStatement()
            .condition(blockRef.invoke("isNull"boolean.classblockPosition))
            .ifTrue(constantInt(0))
            .ifFalse(type.invoke("hash"int.classblockRefblockPosition));
    }
    private void generateRowEqualsRowMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<TypejoinChannelTypes)
    {
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(
                a(),
                "rowEqualsRow",
                type(boolean.class),
                arg("leftPosition"int.class),
                arg("leftBlocks"com.facebook.presto.spi.block.Block[].class),
                arg("rightPosition"int.class),
                arg("rightBlocks"com.facebook.presto.spi.block.Block[].class));
        Scope compilerContext = hashPositionMethod.getScope();
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(callSiteBinderjoinChannelTypes.get(index));
            ByteCodeExpression leftBlock = compilerContext
                    .getVariable("leftBlocks")
                    .getElement(index);
            ByteCodeExpression rightBlock = compilerContext
                    .getVariable("rightBlocks")
                    .getElement(index);
            LabelNode checkNextField = new LabelNode("checkNextField");
            hashPositionMethod
                    .getBody()
                    .append(typeEquals(
                            type,
                            leftBlock,
                            compilerContext.getVariable("leftPosition"),
                            rightBlock,
                            compilerContext.getVariable("rightPosition")))
                    .ifTrueGoto(checkNextField)
                    .push(false)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }
        hashPositionMethod
                .getBody()
                .push(true)
                .retInt();
    }
    private void generatePositionEqualsRowMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<TypejoinChannelTypes,
            List<FieldDefinitionjoinChannelFields)
    {
        Parameter leftBlockIndex = arg("leftBlockIndex"int.class);
        Parameter leftBlockPosition = arg("leftBlockPosition"int.class);
        Parameter rightPosition = arg("rightPosition"int.class);
        Parameter rightBlocks = arg("rightBlocks"com.facebook.presto.spi.block.Block[].class);
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(
                a(),
                "positionEqualsRow",
                type(boolean.class),
                leftBlockIndex,
                leftBlockPosition,
                rightPosition,
                rightBlocks);
        Variable thisVariable = hashPositionMethod.getThis();
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(callSiteBinderjoinChannelTypes.get(index));
            ByteCodeExpression leftBlock = thisVariable
                    .getField(joinChannelFields.get(index))
                    .invoke("get"Object.classleftBlockIndex)
                    .cast(com.facebook.presto.spi.block.Block.class);
            ByteCodeExpression rightBlock = rightBlocks.getElement(index);
            LabelNode checkNextField = new LabelNode("checkNextField");
            hashPositionMethod
                    .getBody()
                    .append(typeEquals(typeleftBlockleftBlockPositionrightBlockrightPosition))
                    .ifTrueGoto(checkNextField)
                    .push(false)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }
        hashPositionMethod
                .getBody()
                .push(true)
                .retInt();
    }
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<TypejoinChannelTypes,
            List<FieldDefinitionjoinChannelFields)
    {
        Parameter leftBlockIndex = arg("leftBlockIndex"int.class);
        Parameter leftBlockPosition = arg("leftBlockPosition"int.class);
        Parameter rightBlockIndex = arg("rightBlockIndex"int.class);
        Parameter rightBlockPosition = arg("rightBlockPosition"int.class);
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(
                a(),
                "positionEqualsPosition",
                type(boolean.class),
                leftBlockIndex,
                leftBlockPosition,
                rightBlockIndex,
                rightBlockPosition);
        Variable thisVariable = hashPositionMethod.getThis();
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(callSiteBinderjoinChannelTypes.get(index));
            ByteCodeExpression leftBlock = thisVariable
                    .getField(joinChannelFields.get(index))
                    .invoke("get"Object.classleftBlockIndex)
                    .cast(com.facebook.presto.spi.block.Block.class);
            ByteCodeExpression rightBlock = thisVariable
                    .getField(joinChannelFields.get(index))
                    .invoke("get"Object.classrightBlockIndex)
                    .cast(com.facebook.presto.spi.block.Block.class);
            LabelNode checkNextField = new LabelNode("checkNextField");
            hashPositionMethod
                    .getBody()
                    .append(typeEquals(typeleftBlockleftBlockPositionrightBlockrightBlockPosition))
                    .ifTrueGoto(checkNextField)
                    .push(false)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }
        hashPositionMethod
                .getBody()
                .push(true)
                .retInt();
    }
    private static ByteCodeNode typeEquals(
            ByteCodeExpression type,
            ByteCodeExpression leftBlock,
            ByteCodeExpression leftBlockPosition,
            ByteCodeExpression rightBlock,
            ByteCodeExpression rightBlockPosition)
    {
        IfStatement ifStatement = new IfStatement();
        ifStatement.condition()
                .append(leftBlock.invoke("isNull"boolean.classleftBlockPosition))
                .append(rightBlock.invoke("isNull"boolean.classrightBlockPosition))
                .append(.);
        ifStatement.ifTrue()
                .append(leftBlock.invoke("isNull"boolean.classleftBlockPosition))
                .append(rightBlock.invoke("isNull"boolean.classrightBlockPosition))
                .append(.);
        ifStatement.ifFalse().append(type.invoke("equalTo"boolean.classleftBlockleftBlockPositionrightBlockrightBlockPosition));
        return ifStatement;
    }
    public static class LookupSourceFactory
    {
        private final Constructor<? extends LookupSourceconstructor;
        private final PagesHashStrategyFactory pagesHashStrategyFactory;
        public LookupSourceFactory(Class<? extends LookupSourcelookupSourceClassPagesHashStrategyFactory pagesHashStrategyFactory)
        {
            this. = pagesHashStrategyFactory;
            try {
                 = lookupSourceClass.getConstructor(LongArrayList.classList.classPagesHashStrategy.class);
            }
            catch (NoSuchMethodException e) {
                throw Throwables.propagate(e);
            }
        }
        public LookupSource createLookupSource(LongArrayList addressesList<TypetypesList<List<com.facebook.presto.spi.block.Block>> channelsOptional<IntegerhashChannel)
        {
            PagesHashStrategy pagesHashStrategy = .createPagesHashStrategy(channelshashChannel);
            try {
                return .newInstance(addressestypespagesHashStrategy);
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }
    public static class PagesHashStrategyFactory
    {
        private final Constructor<? extends PagesHashStrategyconstructor;
        public PagesHashStrategyFactory(Class<? extends PagesHashStrategypagesHashStrategyClass)
        {
            try {
                 = pagesHashStrategyClass.getConstructor(List.classOptional.class);
            }
            catch (NoSuchMethodException e) {
                throw Throwables.propagate(e);
            }
        }
        public PagesHashStrategy createPagesHashStrategy(List<? extends List<com.facebook.presto.spi.block.Block>> channelsOptional<IntegerhashChannel)
        {
            try {
                return .newInstance(channelshashChannel);
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }
    private static final class CacheKey
    {
        private final List<Typetypes;
        private final List<IntegerjoinChannels;
        private CacheKey(List<? extends TypetypesList<IntegerjoinChannels)
        {
            this. = ImmutableList.copyOf(checkNotNull(types"types is null"));
            this. = ImmutableList.copyOf(checkNotNull(joinChannels"joinChannels is null"));
        }
        private List<TypegetTypes()
        {
            return ;
        }
        private List<IntegergetJoinChannels()
        {
            return ;
        }
        @Override
        public int hashCode()
        {
            return Objects.hash();
        }
        @Override
        public boolean equals(Object obj)
        {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof CacheKey)) {
                return false;
            }
            CacheKey other = (CacheKeyobj;
            return Objects.equals(this.other.types) &&
                    Objects.equals(this.other.joinChannels);
        }
    }
New to GrepCode? Check out our FAQ X