Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.operator.scalar;
 
 
 import java.util.List;
 import java.util.Map;
 
 import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
 import static com.facebook.presto.block.BlockAssertions.createBooleansBlock;
 import static com.facebook.presto.block.BlockAssertions.createDoublesBlock;
 import static com.facebook.presto.block.BlockAssertions.createLongsBlock;
 import static com.facebook.presto.block.BlockAssertions.createStringsBlock;
 import static com.facebook.presto.operator.scalar.FunctionAssertions.TestSplit.createNormalSplit;
 import static com.facebook.presto.operator.scalar.FunctionAssertions.TestSplit.createRecordSetSplit;
 import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
 import static com.facebook.presto.spi.type.BigintType.BIGINT;
 import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
 import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
 import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
 import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.analyzeExpressionsWithSymbols;
 import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput;
 import static com.facebook.presto.sql.planner.LocalExecutionPlanner.toTypes;
 import static com.facebook.presto.sql.planner.optimizations.CanonicalizeExpressions.canonicalizeExpression;
 import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
 import static com.google.common.base.Preconditions.checkNotNull;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.testing.Assertions.assertInstanceOf;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
/**
 * Test helper that evaluates a SQL projection or filter through several execution paths
 * (compiled operator, interpreter, scan operators over two kinds of split and, when the
 * expression references no bound symbol, a full query) and checks that all paths agree.
 */
public final class FunctionAssertions
// NOTE(review): the opening brace of the class body is not present after the class header;
// it appears to have been lost when this listing was extracted.
    // Thread pool of daemon threads named "test-%s".
    private static final ExecutorService EXECUTOR = newCachedThreadPool(daemonThreadsNamed("test-%s"));
    private static final SqlParser SQL_PARSER = new SqlParser();
    // Single-position page with one block per bound channel (0..6); the values line up with
    // the symbols declared in INPUT_MAPPING.
    // NOTE(review): the last DateTime argument is reduced to "." and the cast in
    // "(Stringnull)" has lost its closing parenthesis — restore both from the original file.
    private static final Page SOURCE_PAGE = new Page(
            createLongsBlock(1234L),
            createStringsBlock("hello"),
            createDoublesBlock(12.34),
            createBooleansBlock(true),
            createLongsBlock(new DateTime(2001, 8, 22, 3, 4, 5, 321, .).getMillis()),
            createStringsBlock("%el%"),
            createStringsBlock((Stringnull));
    // Page constructed from the single int argument 1 and no blocks.
    private static final Page ZERO_CHANNEL_PAGE = new Page(1);
    // Keyed by channel index 0..6.
    // NOTE(review): the generic arguments are fused with the field name and every value passed
    // to put(...) is missing; presumably these are the channel types — confirm against the
    // original file.
    private static final Map<IntegerTypeINPUT_TYPES = ImmutableMap.<IntegerType>builder()
            .put(0, )
            .put(1, )
            .put(2, )
            .put(3, )
            .put(4, )
            .put(5, )
            .put(6, )
            .build();
    // Maps each bound symbol to the channel index that carries its value in SOURCE_PAGE.
    // NOTE(review): the generic arguments are fused with the field name.
    private static final Map<SymbolIntegerINPUT_MAPPING = ImmutableMap.<SymbolInteger>builder()
            .put(new Symbol("bound_long"), 0)
            .put(new Symbol("bound_string"), 1)
            .put(new Symbol("bound_double"), 2)
            .put(new Symbol("bound_boolean"), 3)
            .put(new Symbol("bound_timestamp"), 4)
            .put(new Symbol("bound_pattern"), 5)
            .put(new Symbol("bound_null_string"), 6)
            .build();
    // Keyed by the same bound symbols as INPUT_MAPPING.
    // NOTE(review): the generic arguments are fused and every value is missing; presumably
    // the type of each symbol — confirm against the original file.
    private static final Map<SymbolTypeSYMBOL_TYPES = ImmutableMap.<SymbolType>builder()
            .put(new Symbol("bound_long"), )
            .put(new Symbol("bound_string"), )
            .put(new Symbol("bound_double"), )
            .put(new Symbol("bound_boolean"), )
            .put(new Symbol("bound_timestamp"), )
            .put(new Symbol("bound_pattern"), )
            .put(new Symbol("bound_null_string"), )
            .build();
    private static final PageSourceProvider PAGE_SOURCE_PROVIDER = new TestPageSourceProvider();
    private static final PlanNodeId SOURCE_ID = new PlanNodeId("scan");
    // Per-instance state, all assigned in the constructor.
    private final Session session;
    private final LocalQueryRunner runner;
    private final Metadata metadata;
    private final ExpressionCompiler compiler;
    public FunctionAssertions()
    {
        this();
    }
    public FunctionAssertions(Session session)
    {
        this. = checkNotNull(session"session is null");
         = new LocalQueryRunner(session);
         = .getMetadata();
         = new ExpressionCompiler();
    }
    public Metadata getMetadata()
    {
        return ;
    }
    /**
     * Passes the given functions to an {@code addFunctions} call and returns this instance so
     * that calls can be chained.
     */
    // NOTE(review): the closing '>' of the parameter type and the receiver of .addFunctions(...)
    // are missing from this listing; the receiver is presumably one of the instance fields
    // (metadata or runner) — confirm against the original file.
    public FunctionAssertions addFunctions(List<ParametricFunctionfunctionInfos)
    {
        .addFunctions(functionInfos);
        return this;
    }
    /**
     * Returns this instance so that calls can be chained.
     */
    public FunctionAssertions addScalarFunctions(Class<?> clazz)
    {
        // NOTE(review): clazz is not used in the visible body; the statement that registers
        // the functions declared on clazz appears to be missing from this listing — confirm.
        return this;
    }
    public void assertFunction(String projectionObject expected)
    {
        if (expected instanceof Integer) {
            expected = ((Integerexpected).longValue();
        }
        else if (expected instanceof Slice) {
            expected = ((Sliceexpected).toString();
        }
        Object actual = selectSingleValue(projection);
        try {
            assertEquals(actualexpected);
        }
        catch (Throwable e) {
            throw e;
        }
    }
    public void assertFunctionNull(String projection)
    {
        assertNull(selectSingleValue(projection));
    }
    public void assertInvalidFunction(String projection)
    {
        try {
            assertFunction(projectionnull);
            fail();
        }
        catch (PrestoException e) {
        }
    }
    public void tryEvaluate(String expression)
    {
        tryEvaluate(expression);
    }
    public void tryEvaluate(String expressionSession session)
    {
        selectUniqueValue(expressionsession);
    }
    public void tryEvaluateWithAll(String expressionSession session)
    {
        executeProjectionWithAll(expressionsession);
    }
    private Object selectSingleValue(String projectionExpressionCompiler compiler)
    {
        return selectUniqueValue(projectioncompiler);
    }
    private Object selectUniqueValue(String projectionSession sessionExpressionCompiler compiler)
    {
        List<Objectresults = executeProjectionWithAll(projectionsessioncompiler);
        HashSet<ObjectresultSet = new HashSet<>(results);
        // we should only have a single result
        assertTrue(resultSet.size() == 1, "Expected only one result unique result, but got " + resultSet);
        return Iterables.getOnlyElement(resultSet);
    }
    public List<ObjectexecuteProjectionWithAll(String projectionSession sessionExpressionCompiler compiler)
    {
        checkNotNull(projection"projection is null");
        Expression projectionExpression = createExpression(projection);
        List<Objectresults = new ArrayList<>();
        //
        // If the projection does not need bound values, execute query using full engine
        if (!needsBoundValue(projectionExpression)) {
            MaterializedResult result = .execute("SELECT " + projection);
            assertEquals(result.getTypes().size(), 1);
            assertEquals(result.getMaterializedRows().size(), 1);
            Object queryResult = Iterables.getOnlyElement(result.getMaterializedRows()).getField(0);
            results.add(queryResult);
        }
        // execute as standalone operator
        OperatorFactory operatorFactory = compileFilterProject(projectionExpressioncompiler);
        Object directOperatorValue = selectSingleValue(operatorFactorysession);
        results.add(directOperatorValue);
        // interpret
        Object interpretedValue = selectSingleValue(interpretedFilterProject(projectionExpressionsession));
        results.add(interpretedValue);
        // execute over normal operator
        SourceOperatorFactory scanProjectOperatorFactory = compileScanFilterProject(projectionExpressioncompiler);
        Object scanOperatorValue = selectSingleValue(scanProjectOperatorFactorycreateNormalSplit(), session);
        results.add(scanOperatorValue);
        // execute over record set
        Object recordValue = selectSingleValue(scanProjectOperatorFactorycreateRecordSetSplit(), session);
        results.add(recordValue);
        //
        // If the projection does not need bound values, execute query using full engine
        if (!needsBoundValue(projectionExpression)) {
            MaterializedResult result = .execute("SELECT " + projection);
            assertEquals(result.getTypes().size(), 1);
            assertEquals(result.getMaterializedRows().size(), 1);
            Object queryResult = Iterables.getOnlyElement(result.getMaterializedRows()).getField(0);
            results.add(queryResult);
        }
        return results;
    }
    private Object selectSingleValue(OperatorFactory operatorFactorySession session)
    {
        Operator operator = operatorFactory.createOperator(createDriverContext(session));
        return selectSingleValue(operator);
    }
    private Object selectSingleValue(SourceOperatorFactory operatorFactorySplit splitSession session)
    {
        SourceOperator operator = operatorFactory.createOperator(createDriverContext(session));
        operator.addSplit(split);
        operator.noMoreSplits();
        return selectSingleValue(operator);
    }
    private Object selectSingleValue(Operator operator)
    {
        Page output = getAtMostOnePage(operator);
        assertNotNull(output);
        assertEquals(output.getPositionCount(), 1);
        assertEquals(output.getChannelCount(), 1);
        Type type = operator.getTypes().get(0);
        Block block = output.getBlock(0);
        assertEquals(block.getPositionCount(), 1);
        return type.getObjectValue(.toConnectorSession(), block, 0);
    }
    public void assertFilter(String filterboolean expectedboolean withNoInputColumns)
    {
        assertFilter(filterexpectedwithNoInputColumns);
    }
    private void assertFilter(String filterboolean expectedboolean withNoInputColumnsExpressionCompiler compiler)
    {
        List<Booleanresults = executeFilterWithAll(filterwithNoInputColumnscompiler);
        HashSet<BooleanresultSet = new HashSet<>(results);
        // we should only have a single result
        assertTrue(resultSet.size() == 1, "Expected only [" + expected + "] result unique result, but got " + resultSet);
        assertEquals((boolean) Iterables.getOnlyElement(resultSet), expected);
    }
    private List<BooleanexecuteFilterWithAll(String filterSession sessionboolean executeWithNoInputColumnsExpressionCompiler compiler)
    {
        checkNotNull(filter"filter is null");
        Expression filterExpression = createExpression(filter);
        List<Booleanresults = new ArrayList<>();
        // execute as standalone operator
        OperatorFactory operatorFactory = compileFilterProject(filterExpressioncompiler);
        results.add(executeFilter(operatorFactorysession));
        if (executeWithNoInputColumns) {
            // execute as standalone operator
            operatorFactory = compileFilterWithNoInputColumns(filterExpressioncompiler);
            results.add(executeFilterWithNoInputColumns(operatorFactorysession));
        }
        // interpret
        boolean interpretedValue = executeFilter(interpretedFilterProject(filterExpressionsession));
        results.add(interpretedValue);
        // execute over normal operator
        SourceOperatorFactory scanProjectOperatorFactory = compileScanFilterProject(filterExpressioncompiler);
        boolean scanOperatorValue = executeFilter(scanProjectOperatorFactorycreateNormalSplit(), session);
        results.add(scanOperatorValue);
        // execute over record set
        boolean recordValue = executeFilter(scanProjectOperatorFactorycreateRecordSetSplit(), session);
        results.add(recordValue);
        //
        // If the filter does not need bound values, execute query using full engine
        if (!needsBoundValue(filterExpression)) {
            MaterializedResult result = .execute("SELECT TRUE WHERE " + filter);
            assertEquals(result.getTypes().size(), 1);
            Boolean queryResult;
            if (result.getMaterializedRows().isEmpty()) {
                queryResult = false;
            }
            else {
                assertEquals(result.getMaterializedRows().size(), 1);
                queryResult = (Boolean) Iterables.getOnlyElement(result.getMaterializedRows()).getField(0);
            }
            results.add(queryResult);
        }
        return results;
    }
    /**
     * Parses the expression text, analyzes it, wraps every node that has a registered
     * coercion in a {@code Cast} to the coerced type, and returns the canonicalized result.
     */
    // NOTE(review): this listing has lost tokens — the separators between the three
    // parameters, the closing '>' of the Map type, the receiver of .createExpression(...)
    // (presumably the SQL parser constant), and possibly further arguments of
    // analyzeExpressionsWithSymbols. Restore them from the original file.
    public static Expression createExpression(String expressionMetadata metadataMap<SymbolTypesymbolTypes)
    {
        Expression parsedExpression = .createExpression(expression);
        final ExpressionAnalysis analysis = analyzeExpressionsWithSymbols(metadatasymbolTypes, ImmutableList.of(parsedExpression));
        Expression rewrittenExpression = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
        {
            @Override
            public Expression rewriteExpression(Expression nodeVoid contextExpressionTreeRewriter<VoidtreeRewriter)
            {
                // Rewrite the children first, then decide whether this node needs a cast.
                Expression rewrittenExpression = treeRewriter.defaultRewrite(nodecontext);
                // cast expression if coercion is registered
                Type coercion = analysis.getCoercion(node);
                if (coercion != null) {
                    rewrittenExpression = new Cast(rewrittenExpressioncoercion.getTypeSignature().toString());
                }
                return rewrittenExpression;
            }
        }, parsedExpression);
        return canonicalizeExpression(rewrittenExpression);
    }
    private static boolean executeFilterWithNoInputColumns(OperatorFactory operatorFactorySession session)
    {
        return executeFilterWithNoInputColumns(operatorFactory.createOperator(createDriverContext(session)));
    }
    private static boolean executeFilter(OperatorFactory operatorFactorySession session)
    {
        return executeFilter(operatorFactory.createOperator(createDriverContext(session)));
    }
    private static boolean executeFilter(SourceOperatorFactory operatorFactorySplit splitSession session)
    {
        SourceOperator operator = operatorFactory.createOperator(createDriverContext(session));
        operator.addSplit(split);
        operator.noMoreSplits();
        return executeFilter(operator);
    }
    private static boolean executeFilter(Operator operator)
    {
        Page page = getAtMostOnePage(operator);
        boolean value;
        if (page != null) {
            assertEquals(page.getPositionCount(), 1);
            assertEquals(page.getChannelCount(), 1);
            assertTrue(operator.getTypes().get(0).getBoolean(page.getBlock(0), 0));
            value = true;
        }
        else {
            value = false;
        }
        return value;
    }
    private static boolean executeFilterWithNoInputColumns(Operator operator)
    {
        Page page = getAtMostOnePage(operator);
        boolean value;
        if (page != null) {
            assertEquals(page.getPositionCount(), 1);
            assertEquals(page.getChannelCount(), 0);
            value = true;
        }
        else {
            value = false;
        }
        return value;
    }
    private static boolean needsBoundValue(Expression projectionExpression)
    {
        final AtomicBoolean hasQualifiedNameReference = new AtomicBoolean();
        projectionExpression.accept(new DefaultTraversalVisitor<VoidVoid>()
        {
            @Override
            protected Void visitQualifiedNameReference(QualifiedNameReference nodeVoid context)
            {
                hasQualifiedNameReference.set(true);
                return null;
            }
        }, null);
        return hasQualifiedNameReference.get();
    }
    /**
     * Builds a filter-and-project operator whose filter and single projection are evaluated by
     * the interpreter, and creates it in a fresh driver context.
     */
    // NOTE(review): this listing has lost tokens — the separators between the three
    // parameters and four constructor arguments in each of the two Interpreted*Function calls
    // (only the first argument and the trailing session survive). Restore them from the
    // original file; the empty argument slots cannot be inferred from what is visible here.
    private Operator interpretedFilterProject(Expression filterExpression projectionSession session)
    {
        FilterFunction filterFunction = new InterpretedFilterFunction(
                filter,
                ,
                ,
                ,
                ,
                session
        );
        ProjectionFunction projectionFunction = new InterpretedProjectionFunction(
                projection,
                ,
                ,
                ,
                ,
                session
        );
        // Operator id 0; the output types are derived from the single projection function.
        OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory(0, new GenericPageProcessor(filterFunction, ImmutableList.of(projectionFunction)), toTypes(
                ImmutableList.of(projectionFunction)));
        return operatorFactory.createOperator(createDriverContext(session));
    }
    // Compiles the filter into a page processor with no projections and wraps it in a
    // filter-and-project operator factory that declares no output types. Symbols are
    // rewritten against an empty symbol-to-channel mapping.
    //
    // NOTE(review): the method signature is missing from this listing, so the block below
    // reads as an instance initializer. The call site passes a filter expression and a
    // compiler to compileFilterWithNoInputColumns and assigns the result to an OperatorFactory,
    // which presumably is this method — confirm. Also stripped: the generic separators, the
    // leading arguments of getExpressionTypesFromInput and the separators inside translate(...).
    {
        filter = ExpressionTreeRewriter.rewriteWith(new SymbolToInputRewriter(ImmutableMap.<SymbolInteger>of()), filter);
        IdentityHashMap<ExpressionTypeexpressionTypes = getExpressionTypesFromInput(, ImmutableList.of(filter));
        try {
            PageProcessor processor = compiler.compilePageProcessor(
                    SqlToRowExpressionTranslator.translate(filterexpressionTypesfalse),
                    ImmutableList.<RowExpression>of());
            return new FilterAndProjectOperator.FilterAndProjectOperatorFactory(0, processor, ImmutableList.<Type>of());
        }
        catch (Throwable e) {
            // Unwrap the wrapper exception so the reported message and cause are the real failure.
            if (e instanceof UncheckedExecutionException) {
                e = e.getCause();
            }
            throw new RuntimeException("Error compiling " + filter + ": " + e.getMessage(), e);
        }
    }
    /**
     * Compiles the filter and the single projection into a page processor and wraps it in a
     * filter-and-project operator factory whose only output type is the projection's type.
     * Any compilation failure is rethrown as a RuntimeException naming the projection.
     */
    // NOTE(review): this listing has lost tokens — the separators between the three
    // parameters, the argument of both SymbolToInputRewriter constructors (presumably the
    // symbol-to-channel mapping constant), the generic separators, the leading arguments of
    // getExpressionTypesFromInput and the separators inside translate(...) and
    // ImmutableList.of(...). Restore them from the original file.
    private OperatorFactory compileFilterProject(Expression filterExpression projectionExpressionCompiler compiler)
    {
        filter = ExpressionTreeRewriter.rewriteWith(new SymbolToInputRewriter(), filter);
        projection = ExpressionTreeRewriter.rewriteWith(new SymbolToInputRewriter(), projection);
        IdentityHashMap<ExpressionTypeexpressionTypes = getExpressionTypesFromInput(, ImmutableList.of(filterprojection));
        try {
            List<RowExpressionprojections = ImmutableList.of(SqlToRowExpressionTranslator.translate(projectionexpressionTypesfalse));
            PageProcessor processor = compiler.compilePageProcessor(
                    SqlToRowExpressionTranslator.translate(filterexpressionTypesfalse),
                    projections);
            return new FilterAndProjectOperator.FilterAndProjectOperatorFactory(0, processor, ImmutableList.of(expressionTypes.get(projection)));
        }
        catch (Throwable e) {
            // Unwrap the wrapper exception so the reported message and cause are the real failure.
            if (e instanceof UncheckedExecutionException) {
                e = e.getCause();
            }
            throw new RuntimeException("Error compiling " + projection + ": " + e.getMessage(), e);
        }
    }
    // Compiles the filter and the single projection twice — once as a cursor processor and
    // once as a page processor — and passes both, together with operator id 0, an empty
    // column list and the projection's type, to a factory constructor.
    //
    // NOTE(review): the method signature is missing from this listing, so the block below
    // reads as an instance initializer; call sites assign the result of
    // compileScanFilterProject(...) to a SourceOperatorFactory, which presumably is this
    // method — confirm. Also missing: the "return new ...Factory(" line that should precede
    // the argument list starting with 0, two of that call's arguments, the last argument of
    // compileCursorProcessor, the SymbolToInputRewriter argument, the leading arguments of
    // getExpressionTypesFromInput and the separators inside translate(...)/ImmutableList.of(...).
    {
        filter = ExpressionTreeRewriter.rewriteWith(new SymbolToInputRewriter(), filter);
        projection = ExpressionTreeRewriter.rewriteWith(new SymbolToInputRewriter(), projection);
        IdentityHashMap<ExpressionTypeexpressionTypes = getExpressionTypesFromInput(, ImmutableList.of(filterprojection));
        try {
            CursorProcessor cursorProcessor = compiler.compileCursorProcessor(
                    SqlToRowExpressionTranslator.translate(filterexpressionTypesfalse),
                    ImmutableList.of(SqlToRowExpressionTranslator.translate(projectionexpressionTypesfalse)),
                    );
            PageProcessor pageProcessor = compiler.compilePageProcessor(
                    SqlToRowExpressionTranslator.translate(filterexpressionTypesfalse),
                    ImmutableList.of(SqlToRowExpressionTranslator.translate(projectionexpressionTypesfalse)));
                    0,
                    ,
                    ,
                    cursorProcessor,
                    pageProcessor,
                    ImmutableList.<ColumnHandle>of(),
                    ImmutableList.of(expressionTypes.get(projection)));
        }
        catch (Throwable e) {
            // Unwrap the wrapper exception so the reported message and cause are the real failure.
            if (e instanceof UncheckedExecutionException) {
                e = e.getCause();
            }
            throw new RuntimeException("Error compiling " + projection + ": " + e.getMessage(), e);
        }
    }
    private static Page getAtMostOnePage(Operator operatorPage sourcePage)
    {
        // add our input page if needed
        if (operator.needsInput()) {
            operator.addInput(sourcePage);
        }
        // try to get the output page
        Page result = operator.getOutput();
        // tell operator to finish
        operator.finish();
        // try to get output until the operator is finished
        while (!operator.isFinished()) {
            // operator should never block
            assertTrue(operator.isBlocked().isDone());
            Page output = operator.getOutput();
            if (output != null) {
                assertNull(result);
                result = output;
            }
        }
        return result;
    }
    private static DriverContext createDriverContext(Session session)
    {
        return new TaskContext(new TaskId("query""stage""task"), session)
                .addPipelineContext(truetrue)
                .addDriverContext();
    }
    /**
     * Page source provider for {@link TestSplit}s: a record-set split is served from an
     * in-memory record set holding one row of the bound values, any other split from a fixed
     * page source.
     */
    // NOTE(review): this listing has lost tokens — the separator and closing '>' in the
    // createPageSource parameters, the closing parenthesis of the TestSplit cast, the column
    // types passed to ImmutableList.<Type>of(), the last DateTime argument (reduced to "."),
    // and the page list passed to FixedPageSource (presumably the shared source page).
    // Restore them from the original file.
    private static class TestPageSourceProvider
            implements PageSourceProvider
    {
        @Override
        public ConnectorPageSource createPageSource(Split splitList<ColumnHandlecolumns)
        {
            // Only splits created by this test class are supported.
            assertInstanceOf(split.getConnectorSplit(), FunctionAssertions.TestSplit.class);
            FunctionAssertions.TestSplit testSplit = (FunctionAssertions.TestSplitsplit.getConnectorSplit();
            if (testSplit.isRecordSet()) {
                // One row whose literal values match those used to build SOURCE_PAGE.
                RecordSet records = InMemoryRecordSet.builder(ImmutableList.<Type>of()).addRow(
                        1234L,
                        "hello",
                        12.34,
                        true,
                        new DateTime(2001, 8, 22, 3, 4, 5, 321, .).getMillis(),
                        "%el%",
                        null
                ).build();
                return new RecordPageSource(records);
            }
            else {
                return new FixedPageSource(ImmutableList.of());
            }
        }
    }
    static class TestSplit
            implements ConnectorSplit
    {
        static Split createRecordSetSplit()
        {
            return new Split("test"new TestSplit(true));
        }
        static Split createNormalSplit()
        {
            return new Split("test"new TestSplit(false));
        }
        private final boolean recordSet;
        private TestSplit(boolean recordSet)
        {
            this. = recordSet;
        }
        private boolean isRecordSet()
        {
            return ;
        }
        @Override
        public boolean isRemotelyAccessible()
        {
            return false;
        }
        @Override
        public List<HostAddressgetAddresses()
        {
            return ImmutableList.of();
        }
        @Override
        public Object getInfo()
        {
            return this;
        }
    }
New to GrepCode? Check out our FAQ X