Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.gen;
 
 
 import java.util.List;
 
 import static com.facebook.presto.block.BlockAssertions.assertBlockEquals;
 import static com.facebook.presto.operator.PageAssertions.assertPageEquals;
 import static com.facebook.presto.spi.type.BigintType.BIGINT;
 import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
 import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
 import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
 import static com.facebook.presto.type.TypeUtils.hashPosition;
 import static com.facebook.presto.type.TypeUtils.positionEqualsPosition;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertTrue;
 
 public class TestJoinCompiler
 {
     // JoinCompiler instance shared by the tests in this class (static final, so created once per class load).
     // NOTE(review): sharing assumes JoinCompiler is safe to reuse across tests and data-provider
     // invocations -- confirm if this suite is ever run in parallel.
     private static final JoinCompiler joinCompiler = new JoinCompiler();
 
     @DataProvider(name = "hashEnabledValues")
     public static Object[][] hashEnabledValuesProvider()
     {
         return new Object[][] { { true }, { false } };
     }
 
     @Test(dataProvider = "hashEnabledValues")
     public void testSingleChannel(boolean hashEnabled)
             throws Exception
     {
         List<TypejoinTypes = ImmutableList.<Type>of();
         List<IntegerjoinChannels = Ints.asList(0);
 
         // compile a single channel hash strategy
         PagesHashStrategyFactory pagesHashStrategyFactory = .compilePagesHashStrategyFactory(joinTypesjoinChannels);
 
         // create hash strategy with a single channel blocks -- make sure there is some overlap in values
         List<Blockchannel = ImmutableList.of(
                 BlockAssertions.createStringSequenceBlock(10, 20),
                 BlockAssertions.createStringSequenceBlock(20, 30),
                 BlockAssertions.createStringSequenceBlock(15, 25));
 
         Optional<IntegerhashChannel = Optional.empty();
         List<List<Block>> channels = ImmutableList.of(channel);
         if (hashEnabled) {
             ImmutableList.Builder<BlockhashChannelBuilder = ImmutableList.builder();
             for (Block block : channel) {
                 hashChannelBuilder.add(TypeUtils.getHashBlock(joinTypesblock));
             }
             hashChannel = Optional.of(1);
             channels = ImmutableList.of(channelhashChannelBuilder.build());
         }
         PagesHashStrategy hashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(channelshashChannel);
 
         // verify channel count
         assertEquals(hashStrategy.getChannelCount(), 1);
 
         // verify hashStrategy is consistent with equals and hash code from block
         for (int leftBlockIndex = 0; leftBlockIndex < channel.size(); leftBlockIndex++) {
             Block leftBlock = channel.get(leftBlockIndex);
 
             PageBuilder pageBuilder = new PageBuilder(ImmutableList.of());
 
             for (int leftBlockPosition = 0; leftBlockPosition < leftBlock.getPositionCount(); leftBlockPosition++) {
                 // hash code of position must match block hash
                 assertEquals(hashStrategy.hashPosition(leftBlockIndexleftBlockPosition), hashPosition(leftBlockleftBlockPosition));
 
                 // position must be equal to itself
                 assertTrue(hashStrategy.positionEqualsPosition(leftBlockIndexleftBlockPositionleftBlockIndexleftBlockPosition));
 
                 // check equality of every position against every other position in the block
                 for (int rightBlockIndex = 0; rightBlockIndex < channel.size(); rightBlockIndex++) {
                    Block rightBlock = channel.get(rightBlockIndex);
                    for (int rightBlockPosition = 0; rightBlockPosition < rightBlock.getPositionCount(); rightBlockPosition++) {
                        boolean expected = positionEqualsPosition(leftBlockleftBlockPositionrightBlockrightBlockPosition);
                        assertEquals(hashStrategy.positionEqualsRow(leftBlockIndexleftBlockPositionrightBlockPositionrightBlock), expected);
                        assertEquals(hashStrategy.rowEqualsRow(leftBlockPositionnew Block[] {leftBlock}, rightBlockPositionnew Block[] {rightBlock}), expected);
                        assertEquals(hashStrategy.positionEqualsPosition(leftBlockIndexleftBlockPositionrightBlockIndexrightBlockPosition), expected);
                    }
                }
                // check equality of every position against every other position in the block cursor
                for (int rightBlockIndex = 0; rightBlockIndex < channel.size(); rightBlockIndex++) {
                    Block rightBlock = channel.get(rightBlockIndex);
                    for (int rightBlockPosition = 0; rightBlockPosition < rightBlock.getPositionCount(); rightBlockPosition++) {
                        boolean expected = positionEqualsPosition(leftBlockleftBlockPositionrightBlockrightBlockPosition);
                        assertEquals(hashStrategy.positionEqualsRow(leftBlockIndexleftBlockPositionrightBlockPositionrightBlock), expected);
                        assertEquals(hashStrategy.rowEqualsRow(leftBlockPositionnew Block[] {leftBlock}, rightBlockPositionnew Block[] {rightBlock}), expected);
                        assertEquals(hashStrategy.positionEqualsPosition(leftBlockIndexleftBlockPositionrightBlockIndexrightBlockPosition), expected);
                    }
                }
                // write position to output block
                pageBuilder.declarePosition();
                hashStrategy.appendTo(leftBlockIndexleftBlockPositionpageBuilder, 0);
            }
            // verify output block matches
            assertBlockEquals(pageBuilder.build().getBlock(0), leftBlock);
        }
    }
    @Test(dataProvider = "hashEnabledValues")
    public void testMultiChannel(boolean hashEnabled)
            throws Exception
    {
        // compile a single channel hash strategy
        JoinCompiler joinCompiler = new JoinCompiler();
        List<Typetypes = ImmutableList.<Type>of();
        List<TypejoinTypes = ImmutableList.<Type>of();
        List<IntegerjoinChannels = Ints.asList(1, 2, 3, 4);
        // crate hash strategy with a single channel blocks -- make sure there is some overlap in values
        List<BlockextraChannel = ImmutableList.of(
                BlockAssertions.createStringSequenceBlock(10, 20),
                BlockAssertions.createStringSequenceBlock(20, 30),
                BlockAssertions.createStringSequenceBlock(15, 25));
        List<BlockvarcharChannel = ImmutableList.of(
                BlockAssertions.createStringSequenceBlock(10, 20),
                BlockAssertions.createStringSequenceBlock(20, 30),
                BlockAssertions.createStringSequenceBlock(15, 25));
        List<BlocklongChannel = ImmutableList.of(
                BlockAssertions.createLongSequenceBlock(10, 20),
                BlockAssertions.createLongSequenceBlock(20, 30),
                BlockAssertions.createLongSequenceBlock(15, 25));
        List<BlockdoubleChannel = ImmutableList.of(
                BlockAssertions.createDoubleSequenceBlock(10, 20),
                BlockAssertions.createDoubleSequenceBlock(20, 30),
                BlockAssertions.createDoubleSequenceBlock(15, 25));
        List<BlockbooleanChannel = ImmutableList.of(
                BlockAssertions.createBooleanSequenceBlock(10, 20),
                BlockAssertions.createBooleanSequenceBlock(20, 30),
                BlockAssertions.createBooleanSequenceBlock(15, 25));
        Optional<IntegerhashChannel = Optional.empty();
        ImmutableList<List<Block>> channels = ImmutableList.of(extraChannelvarcharChannellongChanneldoubleChannelbooleanChannel);
        List<BlockprecomputedHash = ImmutableList.of();
        if (hashEnabled) {
            ImmutableList.Builder<BlockhashChannelBuilder = ImmutableList.builder();
            for (int i = 0; i < 3; i++) {
                hashChannelBuilder.add(TypeUtils.getHashBlock(joinTypesvarcharChannel.get(i), longChannel.get(i), doubleChannel.get(i), booleanChannel.get(i)));
            }
            hashChannel = Optional.of(5);
            precomputedHash = hashChannelBuilder.build();
            channels = ImmutableList.of(extraChannelvarcharChannellongChanneldoubleChannelbooleanChannelprecomputedHash);
            types = ImmutableList.<Type>of();
        }
        PagesHashStrategyFactory pagesHashStrategyFactory = joinCompiler.compilePagesHashStrategyFactory(typesjoinChannels);
        PagesHashStrategy hashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(channelshashChannel);
        PagesHashStrategy expectedHashStrategy = new SimplePagesHashStrategy(typeschannelsjoinChannelshashChannel);
        // verify channel count
        assertEquals(hashStrategy.getChannelCount(), types.size());
        // verify hashStrategy is consistent with equals and hash code from block
        for (int leftBlockIndex = 0; leftBlockIndex < varcharChannel.size(); leftBlockIndex++) {
            PageBuilder pageBuilder = new PageBuilder(types);
            Block[] leftBlocks = new Block[4];
            leftBlocks[0] = varcharChannel.get(leftBlockIndex);
            leftBlocks[1] = longChannel.get(leftBlockIndex);
            leftBlocks[2] = doubleChannel.get(leftBlockIndex);
            leftBlocks[3] = booleanChannel.get(leftBlockIndex);
            int leftPositionCount = varcharChannel.get(leftBlockIndex).getPositionCount();
            for (int leftBlockPosition = 0; leftBlockPosition < leftPositionCountleftBlockPosition++) {
                // hash code of position must match block hash
                assertEquals(
                        hashStrategy.hashPosition(leftBlockIndexleftBlockPosition),
                        expectedHashStrategy.hashPosition(leftBlockIndexleftBlockPosition));
                // position must be equal to itself
                assertTrue(hashStrategy.positionEqualsPosition(leftBlockIndexleftBlockPositionleftBlockIndexleftBlockPosition));
                // check equality of every position against every other position in the block
                for (int rightBlockIndex = 0; rightBlockIndex < varcharChannel.size(); rightBlockIndex++) {
                    Block rightBlock = varcharChannel.get(rightBlockIndex);
                    for (int rightBlockPosition = 0; rightBlockPosition < rightBlock.getPositionCount(); rightBlockPosition++) {
                        assertEquals(
                                hashStrategy.positionEqualsPosition(leftBlockIndexleftBlockPositionrightBlockIndexrightBlockPosition),
                                expectedHashStrategy.positionEqualsPosition(leftBlockIndexleftBlockPositionrightBlockIndexrightBlockPosition));
                    }
                }
                // check equality of every position against every other position in the block cursor
                for (int rightBlockIndex = 0; rightBlockIndex < varcharChannel.size(); rightBlockIndex++) {
                    Block[] rightBlocks = new Block[4];
                    rightBlocks[0] = varcharChannel.get(rightBlockIndex);
                    rightBlocks[1] = longChannel.get(rightBlockIndex);
                    rightBlocks[2] = doubleChannel.get(rightBlockIndex);
                    rightBlocks[3] = booleanChannel.get(rightBlockIndex);
                    int rightPositionCount = varcharChannel.get(rightBlockIndex).getPositionCount();
                    for (int rightPosition = 0; rightPosition < rightPositionCountrightPosition++) {
                        boolean expected = expectedHashStrategy.positionEqualsRow(leftBlockIndexleftBlockPositionrightPositionrightBlocks);
                        assertEquals(hashStrategy.positionEqualsRow(leftBlockIndexleftBlockPositionrightPositionrightBlocks), expected);
                        assertEquals(hashStrategy.rowEqualsRow(leftBlockPositionleftBlocksrightPositionrightBlocks), expected);
                        assertEquals(hashStrategy.positionEqualsRow(leftBlockIndexleftBlockPositionrightPositionrightBlocks), expected);
                    }
                }
                // write position to output block
                pageBuilder.declarePosition();
                hashStrategy.appendTo(leftBlockIndexleftBlockPositionpageBuilder, 0);
            }
            // verify output block matches
            Page page = pageBuilder.build();
            if (hashEnabled) {
                assertPageEquals(typespagenew Page(
                        extraChannel.get(leftBlockIndex),
                        varcharChannel.get(leftBlockIndex),
                        longChannel.get(leftBlockIndex),
                        doubleChannel.get(leftBlockIndex),
                        booleanChannel.get(leftBlockIndex),
                        precomputedHash.get(leftBlockIndex)));
            }
            else {
                assertPageEquals(typespagenew Page(
                        extraChannel.get(leftBlockIndex),
                        varcharChannel.get(leftBlockIndex),
                        longChannel.get(leftBlockIndex),
                        doubleChannel.get(leftBlockIndex),
                        booleanChannel.get(leftBlockIndex)));
            }
        }
    }
New to GrepCode? Check out our FAQ X