Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.gen;
 
 
 import com.facebook.presto.byteCode.OpCode;

 import java.lang.reflect.InvocationTargetException;
 import java.util.List;

 import static com.facebook.presto.byteCode.Access.FINAL;
 import static com.facebook.presto.byteCode.Access.PRIVATE;
 import static com.facebook.presto.byteCode.Access.PUBLIC;
 import static com.facebook.presto.byteCode.Access.a;
 import static com.facebook.presto.byteCode.NamedParameterDefinition.arg;
 import static com.facebook.presto.byteCode.ParameterizedType.type;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantInt;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantNull;
 import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.notEqual;
 import static com.facebook.presto.sql.gen.Bootstrap.BOOTSTRAP_METHOD;
 import static com.facebook.presto.sql.gen.CompilerUtils.defineClass;
 import static com.facebook.presto.sql.gen.CompilerUtils.makeClassName;
 import static com.facebook.presto.sql.gen.SqlTypeByteCodeExpression.constantType;
 import static com.google.common.base.Preconditions.checkNotNull;
 
 public class JoinCompiler
 {
     private final LoadingCache<CacheKeyLookupSourceFactorylookupSourceFactories = CacheBuilder.newBuilder().maximumSize(1000).build(
             new CacheLoader<CacheKeyLookupSourceFactory>()
             {
                 @Override
                 public LookupSourceFactory load(CacheKey key)
                         throws Exception
                 {
                     return internalCompileLookupSourceFactory(key.getTypes(), key.getJoinChannels());
                 }
             });
 
     private final LoadingCache<CacheKeyClass<? extends PagesHashStrategy>> hashStrategies = CacheBuilder.newBuilder().maximumSize(1000).build(
             new CacheLoader<CacheKeyClass<? extends PagesHashStrategy>>() {
                 @Override
                 public Class<? extends PagesHashStrategyload(CacheKey key)
                         throws Exception
                 {
                     return internalCompileHashStrategy(key.getTypes(), key.getJoinChannels());
                 }
             });
 
     public LookupSourceFactory compileLookupSourceFactory(List<? extends TypetypesList<IntegerjoinChannels)
     {
         try {
             return .get(new CacheKey(typesjoinChannels));
         }
         catch (ExecutionException | UncheckedExecutionException | ExecutionError e) {
             throw Throwables.propagate(e.getCause());
         }
     }
    {
        checkNotNull(types"types is null");
        checkNotNull(joinChannels"joinChannels is null");
        try {
            return new PagesHashStrategyFactory(.get(new CacheKey(typesjoinChannels)));
        }
            throw Throwables.propagate(e.getCause());
        }
    }
    {
        Class<? extends PagesHashStrategypagesHashStrategyClass = internalCompileHashStrategy(typesjoinChannels);
        Class<? extends LookupSourcelookupSourceClass = IsolatedClass.isolateClass(
                new DynamicClassLoader(getClass().getClassLoader()),
                LookupSource.class,
                InMemoryJoinHash.class);
        return new LookupSourceFactory(lookupSourceClassnew PagesHashStrategyFactory(pagesHashStrategyClass));
    }
    private Class<? extends PagesHashStrategyinternalCompileHashStrategy(List<TypetypesList<IntegerjoinChannels)
    {
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        ClassDefinition classDefinition = new ClassDefinition(new CompilerContext(),
                a(),
                makeClassName("PagesHashStrategy"),
                type(Object.class),
                type(PagesHashStrategy.class));
        List<FieldDefinitionchannelFields = new ArrayList<>();
        for (int i = 0; i < types.size(); i++) {
            FieldDefinition channelField = classDefinition.declareField(a(), "channel_" + itype(List.classcom.facebook.presto.spi.block.Block.class));
            channelFields.add(channelField);
        }
        List<TypejoinChannelTypes = new ArrayList<>();
        List<FieldDefinitionjoinChannelFields = new ArrayList<>();
        for (int i = 0; i < joinChannels.size(); i++) {
            joinChannelTypes.add(types.get(joinChannels.get(i)));
            FieldDefinition channelField = classDefinition.declareField(a(), "joinChannel_" + itype(List.classcom.facebook.presto.spi.block.Block.class));
            joinChannelFields.add(channelField);
        }
        FieldDefinition hashChannelField = classDefinition.declareField(a(), "hashChannel"type(List.classcom.facebook.presto.spi.block.Block.class));
        generateConstructor(classDefinitionjoinChannelschannelFieldsjoinChannelFieldshashChannelField);
        generateGetChannelCountMethod(classDefinitionchannelFields);
        generateAppendToMethod(classDefinitioncallSiteBindertypeschannelFields);
        generateHashPositionMethod(classDefinitioncallSiteBinderjoinChannelTypesjoinChannelFieldshashChannelField);
        generateHashRowMethod(classDefinitioncallSiteBinderjoinChannelTypesjoinChannelFields);
        generatePositionEqualsRowMethod(classDefinitioncallSiteBinderjoinChannelTypesjoinChannelFields);
        generatePositionEqualsPositionMethod(classDefinitioncallSiteBinderjoinChannelTypesjoinChannelFields);
        return defineClass(classDefinitionPagesHashStrategy.classcallSiteBinder.getBindings(), getClass().getClassLoader());
    }
    private void generateConstructor(ClassDefinition classDefinition,
            List<IntegerjoinChannels,
            List<FieldDefinitionchannelFields,
            List<FieldDefinitionjoinChannelFields,
            FieldDefinition hashChannelField)
    {
        CompilerContext compilerContext = new CompilerContext();
        Block constructor = classDefinition.declareConstructor(compilerContext,
                a(),
                arg("channels"type(List.classtype(List.classcom.facebook.presto.spi.block.Block.class))),
                arg("hashChannel"type(Optional.classInteger.class)))
                .getBody()
                .comment("super();")
                .pushThis()
                .invokeConstructor(Object.class);
        constructor.comment("Set channel fields");
        for (int index = 0; index < channelFields.size(); index++) {
            ByteCodeExpression channel = compilerContext.getVariable("channels")
                    .invoke("get"Object.classconstantInt(index))
                    .cast(type(List.classcom.facebook.presto.spi.block.Block.class));
            constructor.append(compilerContext.getVariable("this").setField(channelFields.get(index), channel));
        }
        constructor.comment("Set join channel fields");
        for (int index = 0; index < joinChannelFields.size(); index++) {
            ByteCodeExpression joinChannel = compilerContext.getVariable("channels")
                    .invoke("get"Object.classconstantInt(joinChannels.get(index)))
                    .cast(type(List.classcom.facebook.presto.spi.block.Block.class));
            constructor.append(compilerContext.getVariable("this").setField(joinChannelFields.get(index), joinChannel));
        }
        constructor.comment("Set hashChannel");
        Variable hashChannel = compilerContext.getVariable("hashChannel");
        constructor.append(new IfStatement(
                compilerContext,
                hashChannel.invoke("isPresent"boolean.class),
                compilerContext.getVariable("this").setField(hashChannelFieldcompilerContext.getVariable("channels").invoke("get"Object.classhashChannel.invoke("get"Object.class).cast(Integer.class).cast(int.class))),
                compilerContext.getVariable("this").setField(hashChannelFieldconstantNull(hashChannelField.getType()))
        ));
        constructor.ret();
    }
    private void generateGetChannelCountMethod(ClassDefinition classDefinitionList<FieldDefinitionchannelFields)
    {
        classDefinition.declareMethod(new CompilerContext(),
                a(),
                "getChannelCount",
                type(int.class))
                .getBody()
                .push(channelFields.size())
                .retInt();
    }
    private void generateAppendToMethod(ClassDefinition classDefinitionCallSiteBinder callSiteBinderList<TypetypesList<FieldDefinitionchannelFields)
    {
        CompilerContext compilerContext = new CompilerContext();
        Block appendToBody = classDefinition.declareMethod(compilerContext,
                a(),
                "appendTo",
                type(void.class),
                arg("blockIndex"int.class),
                arg("blockPosition"int.class),
                arg("pageBuilder"PageBuilder.class),
                arg("outputChannelOffset"int.class))
                .getBody();
        for (int index = 0; index < channelFields.size(); index++) {
            Type type = types.get(index);
            ByteCodeExpression typeExpression = constantType(compilerContextcallSiteBindertype);
            ByteCodeExpression block = compilerContext
                    .getVariable("this")
                    .getField(channelFields.get(index))
                    .invoke("get"Object.classcompilerContext.getVariable("blockIndex"))
                    .cast(com.facebook.presto.spi.block.Block.class);
            appendToBody
                    .comment("%s.appendTo(channel_%s.get(blockIndex), blockPosition, pageBuilder.getBlockBuilder(outputChannelOffset + %s));"type.getClass(), indexindex)
                    .append(typeExpression)
                    .append(block)
                    .getVariable("blockPosition")
                    .getVariable("pageBuilder")
                    .getVariable("outputChannelOffset")
                    .push(index)
                    .append(.)
                    .invokeVirtual(PageBuilder.class"getBlockBuilder"BlockBuilder.classint.class)
                    .invokeInterface(Type.class"appendTo"void.classcom.facebook.presto.spi.block.Block.classint.classBlockBuilder.class);
        }
        appendToBody.ret();
    }
    private void generateHashPositionMethod(ClassDefinition classDefinitionCallSiteBinder callSiteBinderList<TypejoinChannelTypesList<FieldDefinitionjoinChannelFieldsFieldDefinition hashChannelField)
    {
        CompilerContext compilerContext = new CompilerContext();
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(compilerContext,
                a(),
                "hashPosition",
                type(int.class),
                arg("blockIndex"int.class),
                arg("blockPosition"int.class));
        Variable thisVariable = compilerContext.getVariable("this");
        ByteCodeExpression hashChannel = thisVariable.getField(hashChannelField);
        ByteCodeExpression bigintType = constantType(compilerContextcallSiteBinder.);
        Variable blockIndex = compilerContext.getVariable("blockIndex");
        Variable blockPosition = compilerContext.getVariable("blockPosition");
        IfStatementBuilder ifStatementBuilder = new IfStatementBuilder(compilerContext);
        ifStatementBuilder.condition(notEqual(hashChannelconstantNull(hashChannelField.getType())));
        ifStatementBuilder.ifTrue(
                bigintType.invoke(
                        "getLong",
                        long.class,
                        hashChannel.invoke("get"Object.classblockIndex).cast(com.facebook.presto.spi.block.Block.class),
                        blockPosition)
                        .cast(int.class)
                        .ret()
        );
        hashPositionMethod
                .getBody()
                .append(ifStatementBuilder.build());
        Variable resultVariable = hashPositionMethod.getCompilerContext().declareVariable(int.class"result");
        hashPositionMethod.getBody().push(0).putVariable(resultVariable);
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(compilerContextcallSiteBinderjoinChannelTypes.get(index));
            ByteCodeExpression block = compilerContext
                    .getVariable("this")
                    .getField(joinChannelFields.get(index))
                    .invoke("get"Object.classcompilerContext.getVariable("blockIndex"))
                    .cast(com.facebook.presto.spi.block.Block.class);
            hashPositionMethod
                    .getBody()
                    .getVariable(resultVariable)
                    .push(31)
                    .append(.)
                    .append(typeHashCode(compilerContexttypeblockcompilerContext.getVariable("blockPosition")))
                    .append(.)
                    .putVariable(resultVariable);
        }
        hashPositionMethod
                .getBody()
                .getVariable(resultVariable)
                .retInt();
    }
    private void generateHashRowMethod(ClassDefinition classDefinitionCallSiteBinder callSiteBinderList<TypejoinChannelTypesList<FieldDefinitionjoinChannelFields)
    {
        CompilerContext compilerContext = new CompilerContext();
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(compilerContext,
                a(),
                "hashRow",
                type(int.class),
                arg("position"int.class),
                arg("blocks"com.facebook.presto.spi.block.Block[].class));
        Variable resultVariable = hashPositionMethod.getCompilerContext().declareVariable(int.class"result");
        hashPositionMethod.getBody().push(0).putVariable(resultVariable);
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(compilerContextcallSiteBinderjoinChannelTypes.get(index));
            ByteCodeExpression block = compilerContext
                    .getVariable("blocks")
                    .getElement(index)
                    .cast(com.facebook.presto.spi.block.Block.class);
            hashPositionMethod
                    .getBody()
                    .getVariable(resultVariable)
                    .push(31)
                    .append(.)
                    .append(typeHashCode(compilerContexttypeblockcompilerContext.getVariable("position")))
                    .append(.)
                    .putVariable(resultVariable);
        }
        hashPositionMethod
                .getBody()
                .getVariable(resultVariable)
                .retInt();
    }
    private static ByteCodeNode typeHashCode(CompilerContext compilerContextByteCodeExpression typeByteCodeExpression blockRefByteCodeExpression blockPosition)
    {
        IfStatementBuilder ifStatementBuilder = new IfStatementBuilder(compilerContext);
        ifStatementBuilder.condition(new Block(compilerContext).append(blockRef.invoke("isNull"boolean.classblockPosition)));
        ifStatementBuilder.ifTrue(new Block(compilerContext).push(0));
        ifStatementBuilder.ifFalse(new Block(compilerContext).append(type.invoke("hash"int.classblockRefblockPosition)));
        return ifStatementBuilder.build();
    }
    private void generatePositionEqualsRowMethod(
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<TypejoinChannelTypes,
            List<FieldDefinitionjoinChannelFields)
    {
        CompilerContext compilerContext = new CompilerContext();
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(compilerContext,
                a(),
                "positionEqualsRow",
                type(boolean.class),
                arg("leftBlockIndex"int.class),
                arg("leftBlockPosition"int.class),
                arg("rightPosition"int.class),
                arg("rightBlocks"com.facebook.presto.spi.block.Block[].class));
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(compilerContextcallSiteBinderjoinChannelTypes.get(index));
            ByteCodeExpression leftBlock = compilerContext
                    .getVariable("this")
                    .getField(joinChannelFields.get(index))
                    .invoke("get"Object.classcompilerContext.getVariable("leftBlockIndex"))
                    .cast(com.facebook.presto.spi.block.Block.class);
            ByteCodeExpression rightBlock = compilerContext
                    .getVariable("rightBlocks")
                    .getElement(index);
            LabelNode checkNextField = new LabelNode("checkNextField");
            hashPositionMethod
                    .getBody()
                    .append(typeEquals(compilerContext,
                            type,
                            leftBlock,
                            compilerContext.getVariable("leftBlockPosition"),
                            rightBlock,
                            compilerContext.getVariable("rightPosition")))
                    .ifTrueGoto(checkNextField)
                    .push(false)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }
        hashPositionMethod
                .getBody()
                .push(true)
                .retInt();
    }
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            List<TypejoinChannelTypes,
            List<FieldDefinitionjoinChannelFields)
    {
        CompilerContext compilerContext = new CompilerContext();
        MethodDefinition hashPositionMethod = classDefinition.declareMethod(compilerContext,
                a(),
                "positionEqualsPosition",
                type(boolean.class),
                arg("leftBlockIndex"int.class),
                arg("leftBlockPosition"int.class),
                arg("rightBlockIndex"int.class),
                arg("rightBlockPosition"int.class));
        for (int index = 0; index < joinChannelTypes.size(); index++) {
            ByteCodeExpression type = constantType(compilerContextcallSiteBinderjoinChannelTypes.get(index));
            Variable blockIndex = compilerContext.getVariable("leftBlockIndex");
            ByteCodeExpression leftBlock = compilerContext
                    .getVariable("this")
                    .getField(joinChannelFields.get(index))
                    .invoke("get"Object.classblockIndex)
                    .cast(com.facebook.presto.spi.block.Block.class);
            ByteCodeExpression rightBlock = compilerContext
                    .getVariable("this")
                    .getField(joinChannelFields.get(index))
                    .invoke("get"Object.classcompilerContext.getVariable("rightBlockIndex"))
                    .cast(com.facebook.presto.spi.block.Block.class);
            LabelNode checkNextField = new LabelNode("checkNextField");
            hashPositionMethod
                    .getBody()
                    .append(typeEquals(compilerContext,
                            type,
                            leftBlock,
                            compilerContext.getVariable("leftBlockPosition"),
                            rightBlock,
                            compilerContext.getVariable("rightBlockPosition")))
                    .ifTrueGoto(checkNextField)
                    .push(false)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }
        hashPositionMethod
                .getBody()
                .push(true)
                .retInt();
    }
    private static ByteCodeNode typeEquals(
            CompilerContext compilerContext,
            ByteCodeExpression type,
            ByteCodeExpression leftBlock,
            ByteCodeExpression leftBlockPosition,
            ByteCodeExpression rightBlock,
            ByteCodeExpression rightBlockPosition)
    {
        IfStatementBuilder ifStatementBuilder = new IfStatementBuilder(compilerContext);
        ifStatementBuilder.condition(new Block(compilerContext)
                .append(leftBlock.invoke("isNull"boolean.classleftBlockPosition))
                .append(rightBlock.invoke("isNull"boolean.classrightBlockPosition))
                .append(.));
        ifStatementBuilder.ifTrue(new Block(compilerContext)
                .append(leftBlock.invoke("isNull"boolean.classleftBlockPosition))
                .append(rightBlock.invoke("isNull"boolean.classrightBlockPosition))
                .append(.));
        ifStatementBuilder.ifFalse(new Block(compilerContext).append(type.invoke("equalTo"boolean.classleftBlockleftBlockPositionrightBlockrightBlockPosition)));
        return ifStatementBuilder.build();
    }
    public static class LookupSourceFactory
    {
        private final Constructor<? extends LookupSourceconstructor;
        private final PagesHashStrategyFactory pagesHashStrategyFactory;
        public LookupSourceFactory(Class<? extends LookupSourcelookupSourceClassPagesHashStrategyFactory pagesHashStrategyFactory)
        {
            this. = pagesHashStrategyFactory;
            try {
                 = lookupSourceClass.getConstructor(LongArrayList.classList.classPagesHashStrategy.classOperatorContext.class);
            }
            catch (NoSuchMethodException e) {
                throw Throwables.propagate(e);
            }
        }
        public LookupSource createLookupSource(LongArrayList addressesList<TypetypesList<List<com.facebook.presto.spi.block.Block>> channelsOptional<IntegerhashChannelOperatorContext operatorContext)
        {
            PagesHashStrategy pagesHashStrategy = .createPagesHashStrategy(channelshashChannel);
            try {
                return .newInstance(addressestypespagesHashStrategyoperatorContext);
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }
    public static class PagesHashStrategyFactory
    {
        private final Constructor<? extends PagesHashStrategyconstructor;
        public PagesHashStrategyFactory(Class<? extends PagesHashStrategypagesHashStrategyClass)
        {
            try {
                 = pagesHashStrategyClass.getConstructor(List.classOptional.class);
            }
            catch (NoSuchMethodException e) {
                throw Throwables.propagate(e);
            }
        }
        public PagesHashStrategy createPagesHashStrategy(List<? extends List<com.facebook.presto.spi.block.Block>> channelsOptional<IntegerhashChannel)
        {
            try {
                return .newInstance(channelshashChannel);
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }
    private static final class CacheKey
    {
        private final List<Typetypes;
        private final List<IntegerjoinChannels;
        private CacheKey(List<? extends TypetypesList<IntegerjoinChannels)
        {
            this. = ImmutableList.copyOf(checkNotNull(types"types is null"));
            this. = ImmutableList.copyOf(checkNotNull(joinChannels"joinChannels is null"));
        }
        private List<TypegetTypes()
        {
            return ;
        }
        private List<IntegergetJoinChannels()
        {
            return ;
        }
        @Override
        public int hashCode()
        {
            return Objects.hash();
        }
        @Override
        public boolean equals(Object obj)
        {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof CacheKey)) {
                return false;
            }
            CacheKey other = (CacheKeyobj;
            return Objects.equals(this.other.types) &&
                    Objects.equals(this.other.joinChannels);
        }
    }
New to GrepCode? Check out our FAQ X