Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.planner;
 
 import org.joni.Regex;
 
 import java.util.List;
 import java.util.Set;
 
 import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
 import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
 import static com.facebook.presto.sql.planner.LiteralInterpreter.toExpression;
 import static com.facebook.presto.sql.planner.LiteralInterpreter.toExpressions;
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Predicates.instanceOf;
 import static com.google.common.collect.Iterables.any;
 import static java.nio.charset.StandardCharsets.UTF_8;
 
 public class ExpressionInterpreter
 {
     /** The expression this instance evaluates (interpreter) or partially evaluates (optimizer). */
     private final Expression expression;
     private final Metadata metadata;
     // Connector-level session; the constructor assigns session.toConnectorSession() (field name stripped in this listing).
     private final ConnectorSession session;
     // true when created via expressionOptimizer(...), false when created via expressionInterpreter(...).
     private final boolean optimize;
     // NOTE(review): generic arguments were garbled in this listing; presumably
     // IdentityHashMap<Expression, Type> expressionTypes, i.e. sub-expression types looked up by node identity.
     private final IdentityHashMap<ExpressionTypeexpressionTypes;
 
     // Visitor that performs the evaluation; created in the constructor.
     private final Visitor visitor;
 
     // identity-based cache for LIKE expressions with constant pattern and escape char
     private final IdentityHashMap<LikePredicateRegexlikePatternCache = new IdentityHashMap<>();
     // Identity-based cache for IN lists: each analysed InListExpression maps to a set of its constant values,
     // or to null when the set optimization cannot be applied (see visitInPredicate).
     // NOTE(review): both caches are mutated during evaluation without synchronization; presumably an
     // instance is confined to one thread — confirm.
     private final IdentityHashMap<InListExpressionSet<Object>> inListCache = new IdentityHashMap<>();
 
     /**
      * Creates an instance in interpreter mode, which fully evaluates the expression against concrete inputs
      * (see the evaluate methods).
      *
      * @param expression the expression to evaluate; must not be null
      * @param metadata catalog metadata; must not be null
      * @param session the session; must not be null
      * @param expressionTypes types of sub-expressions keyed by node identity
      *        (NOTE(review): not null-checked here, unlike the other arguments)
      * @return a new interpreter ({@code optimize == false})
      * @throws NullPointerException if expression, metadata or session is null
      */
     public static ExpressionInterpreter expressionInterpreter(Expression expressionMetadata metadataSession sessionIdentityHashMap<ExpressionTypeexpressionTypes)
     {
         checkNotNull(expression"expression is null");
        checkNotNull(metadata"metadata is null");
        checkNotNull(session"session is null");
        return new ExpressionInterpreter(expressionmetadatasessionexpressionTypesfalse);
    }
    /**
     * Creates an instance in optimizer mode, which partially evaluates the expression: parts that
     * cannot be resolved are returned as rewritten {@code Expression} nodes (see {@code optimize}).
     *
     * @param expression the expression to optimize; must not be null
     * @param metadata catalog metadata; must not be null
     * @param session the session; must not be null
     * @param expressionTypes types of sub-expressions keyed by node identity
     *        (NOTE(review): not null-checked here, unlike the other arguments)
     * @return a new optimizer ({@code optimize == true})
     * @throws NullPointerException if expression, metadata or session is null
     */
    public static ExpressionInterpreter expressionOptimizer(Expression expressionMetadata metadataSession sessionIdentityHashMap<ExpressionTypeexpressionTypes)
    {
        checkNotNull(expression"expression is null");
        checkNotNull(metadata"metadata is null");
        checkNotNull(session"session is null");
        return new ExpressionInterpreter(expressionmetadatasessionexpressionTypestrue);
    }
    /**
     * Stores the arguments, converts the engine session to a connector session and creates the visitor.
     * Callers go through the two static factories, which perform the null checks.
     */
    // NOTE(review): the assigned field names were stripped from this listing ("this. = ...");
    // presumably expression, metadata, session, expressionTypes, optimize and visitor in that order.
    private ExpressionInterpreter(Expression expressionMetadata metadataSession sessionIdentityHashMap<ExpressionTypeexpressionTypesboolean optimize)
    {
        this. = expression;
        this. = metadata;
        this. = session.toConnectorSession();
        this. = expressionTypes;
        this. = optimize;
        this. = new Visitor();
    }
    /**
     * Evaluates the expression against the current row of {@code inputs}.
     *
     * @param inputs cursor positioned on the row to evaluate; input references read their channels from it
     * @return the value, or {@code null} for SQL NULL
     * @throws IllegalStateException if this instance is an optimizer
     */
    // NOTE(review): the checkState condition and the process() receiver were stripped;
    // presumably !optimize and visitor.
    public Object evaluate(RecordCursor inputs)
    {
        checkState(!"evaluate(RecordCursor) not allowed for optimizer");
        return .process(inputs);
    }
    /**
     * Evaluates the expression for one position across the given blocks.
     *
     * @param position the row position within each block
     * @param inputs blocks indexed by input channel (not copied)
     * @return the value, or {@code null} for SQL NULL
     * @throws IllegalStateException if this instance is an optimizer
     */
    public Object evaluate(int positionBlock... inputs)
    {
        checkState(!"evaluate(int, Block...) not allowed for optimizer");
        return .process(new PagePositionContext(positioninputs));
    }
    /**
     * Partially evaluates the expression, resolving symbols through {@code inputs}.
     *
     * @param inputs resolver for unqualified symbol references
     * @return a constant value, {@code null} for SQL NULL, or an {@code Expression} for the part that
     *         could not be resolved
     * @throws IllegalStateException if this instance is an interpreter
     */
    // NOTE(review): the error message says "evaluate(SymbolResolver)" although this method is optimize(...).
    public Object optimize(SymbolResolver inputs)
    {
        checkState("evaluate(SymbolResolver) not allowed for interpreter");
        return .process(inputs);
    }
    /**
     * AST visitor that evaluates (interpreter mode) or partially evaluates (optimizer mode) an expression.
     *
     * <p>Each visit method returns a resolved value, {@code null} for SQL NULL, or an {@code Expression}
     * when the value could not be resolved. Any {@code Expression} result is treated as "unresolved"
     * (see {@code hasUnresolvedValue}).
     *
     * <p>NOTE(review): this listing has lost tokens. Missing pieces include:
     * <ul>
     *   <li>receivers such as the expression-type map in {@code .get(node)};</li>
     *   <li>qualified constants in {@code ..equals(value)} and {@code case :};</li>
     *   <li>commas between arguments;</li>
     *   <li>closing characters of casts and generics.</li>
     * </ul>
     * Comments below describe the visible control flow and call the stripped names "presumed".
     */
    @SuppressWarnings("FloatingPointEquality")
    private class Visitor
            extends AstVisitor<ObjectObject>
    {
        /**
         * Reads the current value of an input channel from the evaluation context.
         *
         * <p>The context must be a {@code PagePositionContext} or a {@code RecordCursor}; any other
         * context throws {@code UnsupportedOperationException}. A null cell yields {@code null}.
         * Otherwise the value is read according to the type's Java representation (boolean, long,
         * double or {@code Slice}); other representations throw {@code UnsupportedOperationException}.
         */
        @Override
        public Object visitInputReference(InputReference nodeObject context)
        {
            // Type lookup; the map receiver was stripped (presumably expressionTypes).
            Type type = .get(node);
            int channel = node.getChannel();
            if (context instanceof PagePositionContext) {
                PagePositionContext pagePositionContext = (PagePositionContextcontext;
                int position = pagePositionContext.getPosition();
                Block block = pagePositionContext.getBlock(channel);
                if (block.isNull(position)) {
                    return null;
                }
                Class<?> javaType = type.getJavaType();
                if (javaType == boolean.class) {
                    return type.getBoolean(blockposition);
                }
                else if (javaType == long.class) {
                    return type.getLong(blockposition);
                }
                else if (javaType == double.class) {
                    return type.getDouble(blockposition);
                }
                else if (javaType == Slice.class) {
                    return type.getSlice(blockposition);
                }
                else {
                    throw new UnsupportedOperationException("not yet implemented");
                }
            }
            else if (context instanceof RecordCursor) {
                RecordCursor cursor = (RecordCursorcontext;
                if (cursor.isNull(channel)) {
                    return null;
                }
                Class<?> javaType = type.getJavaType();
                if (javaType == boolean.class) {
                    return cursor.getBoolean(channel);
                }
                else if (javaType == long.class) {
                    return cursor.getLong(channel);
                }
                else if (javaType == double.class) {
                    return cursor.getDouble(channel);
                }
                else if (javaType == Slice.class) {
                    return cursor.getSlice(channel);
                }
                else {
                    throw new UnsupportedOperationException("not yet implemented");
                }
            }
            // NOTE(review): "myst" in this message is a typo for "must".
            throw new UnsupportedOperationException("Inputs or cursor myst be set");
        }
        /**
         * Resolves an unqualified name as a {@code Symbol} through the context. Names with a prefix are not
         * symbols and are returned unchanged. The context is cast to {@code SymbolResolver} without a check,
         * so any other non-null context type fails with a ClassCastException.
         */
        @Override
        protected Object visitQualifiedNameReference(QualifiedNameReference nodeObject context)
        {
            if (node.getName().getPrefix().isPresent()) {
                // not a symbol
                return node;
            }
            Symbol symbol = Symbol.fromQualifiedName(node.getName());
            return ((SymbolResolvercontext).getValue(symbol);
        }
        /** Literals evaluate to their constant value via {@code LiteralInterpreter.evaluate}. */
        @Override
        protected Object visitLiteral(Literal nodeObject context)
        {
            return LiteralInterpreter.evaluate(node);
        }
        /** IS NULL: rebuilds the predicate around an unresolved operand, otherwise tests the value for null. */
        @Override
        protected Object visitIsNullPredicate(IsNullPredicate nodeObject context)
        {
            Object value = process(node.getValue(), context);
            if (value instanceof Expression) {
                return new IsNullPredicate(toExpression(value.get(node.getValue())));
            }
            return value == null;
        }
        /** IS NOT NULL: rebuilds the predicate around an unresolved operand, otherwise tests the value for non-null. */
        @Override
        protected Object visitIsNotNullPredicate(IsNotNullPredicate nodeObject context)
        {
            Object value = process(node.getValue(), context);
            if (value instanceof Expression) {
                return new IsNotNullPredicate(toExpression(value.get(node.getValue())));
            }
            return value != null;
        }
        /**
         * Searched {@code CASE WHEN cond THEN ...}: evaluates the WHEN conditions in order. The first one equal to
         * the stripped constant (presumably {@code Boolean.TRUE}) selects its result; otherwise the ELSE value is
         * used, or {@code null} if there is none. If a condition reached before a match, or the selected result,
         * is unresolved, the node is returned unchanged.
         */
        @Override
        protected Object visitSearchedCaseExpression(SearchedCaseExpression nodeObject context)
        {
            Expression resultClause = node.getDefaultValue().orElse(null);
            for (WhenClause whenClause : node.getWhenClauses()) {
                Object value = process(whenClause.getOperand(), context);
                if (value instanceof Expression) {
                    // TODO: optimize this case
                    return node;
                }
                if (..equals(value)) {
                    resultClause = whenClause.getResult();
                    break;
                }
            }
            if (resultClause == null) {
                return null;
            }
            Object result = process(resultClausecontext);
            if (result instanceof Expression) {
                return node;
            }
            return result;
        }
        /**
         * Simple {@code CASE operand WHEN value THEN ...}: compares the operand with each WHEN value in order
         * through an operator invocation (operator constant stripped; presumably EQUAL).
         *
         * <p>NULL WHEN values are skipped, and a NULL operand matches nothing, so the ELSE value (or
         * {@code null}) is used. An unresolved operand returns the node unchanged. So does an unresolved
         * WHEN value reached before a match, or an unresolved selected result.
         */
        @Override
        protected Object visitSimpleCaseExpression(SimpleCaseExpression nodeObject context)
        {
            Object operand = process(node.getOperand(), context);
            if (operand instanceof Expression) {
                // TODO: optimize this case
                return node;
            }
            Expression resultClause = node.getDefaultValue().orElse(null);
            if (operand != null) {
                for (WhenClause whenClause : node.getWhenClauses()) {
                    Object value = process(whenClause.getOperand(), context);
                    if (value == null) {
                        continue;
                    }
                    if (value instanceof Expression) {
                        // TODO: optimize this case
                        return node;
                    }
                    if ((BooleaninvokeOperator(.types(node.getOperand(), whenClause.getOperand()), ImmutableList.of(operandvalue))) {
                        resultClause = whenClause.getResult();
                        break;
                    }
                }
            }
            if (resultClause == null) {
                return null;
            }
            Object result = process(resultClausecontext);
            if (result instanceof Expression) {
                return node;
            }
            return result;
        }
        /**
         * COALESCE: returns the first non-null operand value, or {@code null} if every operand is null.
         * If an unresolved operand is reached before any non-null value, the node is returned unchanged.
         */
        @Override
        protected Object visitCoalesceExpression(CoalesceExpression nodeObject context)
        {
            for (Expression expression : node.getOperands()) {
                Object value = process(expressioncontext);
                if (value instanceof Expression) {
                    // TODO: optimize this case
                    return node;
                }
                if (value != null) {
                    return value;
                }
            }
            return null;
        }
        /**
         * IN predicate with SQL NULL semantics: a match yields true; no match with a NULL list entry yields NULL;
         * no match otherwise yields false. A NULL probe value yields NULL.
         *
         * <p>A value list that is not an {@code InListExpression} is returned unchanged when the stripped
         * condition (presumably {@code !optimize}) is false; otherwise it throws.
         *
         * <p>When every list entry is a non-null literal, the evaluated entries are cached as a set per list
         * node and a resolved probe is answered by set lookup. With unresolved parts, a reduced IN predicate is
         * rebuilt.
         *
         * <p>NOTE(review): when the probe value is resolved, NULL list entries are not added to the rebuilt
         * list, so the NULL result implied by {@code hasNullValue} is lost on that path. For example,
         * {@code 2 IN (NULL, y)} becomes {@code 2 IN (y)}. Confirm this is acceptable.
         */
        @Override
        protected Object visitInPredicate(InPredicate nodeObject context)
        {
            Object value = process(node.getValue(), context);
            if (value == null) {
                return null;
            }
            Expression valueListExpression = node.getValueList();
            if (!(valueListExpression instanceof InListExpression)) {
                if (!) {
                    throw new UnsupportedOperationException("IN predicate value list type not yet implemented: " + valueListExpression.getClass().getName());
                }
                return node;
            }
            InListExpression valueList = (InListExpressionvalueListExpression;
            Set<Objectset = .get(valueList);
            // We use the presence of the node in the map to indicate that we've already done
            // the analysis below. If the value is null, it means that we can't apply the HashSet
            // optimization
            if (!.containsKey(valueList)) {
                // isNullLiteral returns true for literals that are NOT NullLiteral, so this requires every entry
                // to be a non-null literal.
                if (Iterables.all(valueList.getValues(), ExpressionInterpreter::isNullLiteral)) {
                    // if all elements are constant, create a set with them
                    set = new HashSet<>();
                    for (Expression expression : valueList.getValues()) {
                        set.add(process(expressioncontext));
                    }
                }
                .put(valueListset);
            }
            if (set != null && !(value instanceof Expression)) {
                return set.contains(value);
            }
            boolean hasUnresolvedValue = false;
            if (value instanceof Expression) {
                hasUnresolvedValue = true;
            }
            boolean hasNullValue = false;
            boolean found = false;
            List<Objectvalues = new ArrayList<>(valueList.getValues().size());
            List<Typetypes = new ArrayList<>(valueList.getValues().size());
            for (Expression expression : valueList.getValues()) {
                Object inValue = process(expressioncontext);
                if (value instanceof Expression || inValue instanceof Expression) {
                    // keep entries that may still matter for the rebuilt predicate
                    hasUnresolvedValue = true;
                    values.add(inValue);
                    types.add(.get(expression));
                    continue;
                }
                if (inValue == null) {
                    hasNullValue = true;
                }
                else if (!found && (BooleaninvokeOperator(.types(node.getValue(), expression), ImmutableList.of(valueinValue))) {
                    // IN does not short-circuit, so every value in the list is still evaluated
                    found = true;
                }
            }
            if (found) {
                return true;
            }
            if (hasUnresolvedValue) {
                Type type = .get(node.getValue());
                return new InPredicate(toExpression(valuetype), new InListExpression(toExpressions(valuestypes)));
            }
            if (hasNullValue) {
                return null;
            }
            return false;
        }
        /**
         * Unary plus/minus. NULL propagates, and an unresolved operand yields a rebuilt node.
         * The first case returns the value as-is (presumably PLUS). The second invokes the resolved unary
         * operator (presumably NEGATION), binding the session when the handle's first parameter is a
         * {@code ConnectorSession}. Case labels, the operator constant and the session reference were
         * stripped from this listing.
         */
        @Override
        protected Object visitArithmeticUnary(ArithmeticUnaryExpression nodeObject context)
        {
            Object value = process(node.getValue(), context);
            if (value == null) {
                return null;
            }
            if (value instanceof Expression) {
                return new ArithmeticUnaryExpression(node.getSign(), toExpression(value.get(node.getValue())));
            }
            switch (node.getSign()) {
                case :
                    return value;
                case :
                    FunctionInfo operatorInfo = .resolveOperator(.types(node.getValue()));
                    MethodHandle handle = operatorInfo.getMethodHandle();
                    if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) {
                        handle = handle.bindTo();
                    }
                    try {
                        return handle.invokeWithArguments(value);
                    }
                    catch (Throwable throwable) {
                        // unchecked throwables pass through; checked ones are wrapped, keeping the cause
                        Throwables.propagateIfInstanceOf(throwableRuntimeException.class);
                        Throwables.propagateIfInstanceOf(throwableError.class);
                        throw new RuntimeException(throwable.getMessage(), throwable);
                    }
            }
            throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign());
        }
        /**
         * Binary arithmetic. A NULL side yields NULL; the right side is not evaluated when the left is NULL.
         * Unresolved operands yield a rebuilt node. Otherwise the {@code OperatorType} named after the node
         * type is invoked.
         */
        @Override
        protected Object visitArithmeticBinary(ArithmeticBinaryExpression nodeObject context)
        {
            Object left = process(node.getLeft(), context);
            if (left == null) {
                return null;
            }
            Object right = process(node.getRight(), context);
            if (right == null) {
                return null;
            }
            if (hasUnresolvedValue(leftright)) {
                return new ArithmeticBinaryExpression(node.getType(), toExpression(left.get(node.getLeft())), toExpression(right.get(node.getRight())));
            }
            return invokeOperator(OperatorType.valueOf(node.getType().name()), types(node.getLeft(), node.getRight()), ImmutableList.of(leftright));
        }
        /**
         * Comparison. For the specially handled type (constant stripped; presumably IS_DISTINCT_FROM):
         * two NULLs give false, and exactly one NULL gives true. For every other type, a NULL side
         * yields NULL.
         *
         * <p>Unresolved operands yield a rebuilt node. Before the operator is invoked, a type is remapped
         * (constants stripped; presumably IS_DISTINCT_FROM to NOT_EQUAL) so that it names an
         * {@code OperatorType}.
         */
        @Override
        protected Object visitComparisonExpression(ComparisonExpression nodeObject context)
        {
            ComparisonExpression.Type type = node.getType();
            Object left = process(node.getLeft(), context);
            if (left == null && !(type == ..)) {
                return null;
            }
            Object right = process(node.getRight(), context);
            if (type == ..) {
                if (left == null && right == null) {
                    return false;
                }
                else if (left == null || right == null) {
                    return true;
                }
            }
            else if (right == null) {
                return null;
            }
            if (hasUnresolvedValue(leftright)) {
                return new ComparisonExpression(typetoExpression(left.get(node.getLeft())), toExpression(right.get(node.getRight())));
            }
            if (type == ..) {
                type = ..;
            }
            return invokeOperator(OperatorType.valueOf(type.name()), types(node.getLeft(), node.getRight()), ImmutableList.of(leftright));
        }
        /**
         * BETWEEN. Any NULL among value, min and max yields NULL; they are evaluated in that order and
         * evaluation stops at the first NULL. Unresolved operands yield a rebuilt predicate. Otherwise the
         * stripped operator (presumably BETWEEN) is invoked.
         */
        @Override
        protected Object visitBetweenPredicate(BetweenPredicate nodeObject context)
        {
            Object value = process(node.getValue(), context);
            if (value == null) {
                return null;
            }
            Object min = process(node.getMin(), context);
            if (min == null) {
                return null;
            }
            Object max = process(node.getMax(), context);
            if (max == null) {
                return null;
            }
            if (hasUnresolvedValue(valueminmax)) {
                return new BetweenPredicate(
                        toExpression(value.get(node.getValue())),
                        toExpression(min.get(node.getMin())),
                        toExpression(max.get(node.getMax())));
            }
            return invokeOperator(.types(node.getValue(), node.getMin(), node.getMax()), ImmutableList.of(valueminmax));
        }
        /**
         * NULLIF(first, second). A NULL first yields NULL, and a NULL second yields first. Unresolved values
         * yield a rebuilt node. Otherwise both values are coerced to their common super type and compared with
         * the stripped operator (presumably EQUAL): equal yields NULL, unequal yields the original, uncoerced
         * first value.
         *
         * <p>NOTE(review): {@code getCommonSuperType(...).get()} assumes a common super type always exists;
         * confirm the analyzer guarantees it.
         */
        @Override
        protected Object visitNullIfExpression(NullIfExpression nodeObject context)
        {
            Object first = process(node.getFirst(), context);
            if (first == null) {
                return null;
            }
            Object second = process(node.getSecond(), context);
            if (second == null) {
                return first;
            }
            Type firstType = .get(node.getFirst());
            Type secondType = .get(node.getSecond());
            if (hasUnresolvedValue(firstsecond)) {
                return new NullIfExpression(toExpression(firstfirstType), toExpression(secondsecondType));
            }
            Type commonType = FunctionRegistry.getCommonSuperType(firstTypesecondType).get();
            FunctionInfo firstCast = .getFunctionRegistry().getCoercion(firstTypecommonType);
            FunctionInfo secondCast = .getFunctionRegistry().getCoercion(secondTypecommonType);
            // cast(first as <common type>) == cast(second as <common type>)
            boolean equal = (BooleaninvokeOperator(
                    .,
                    ImmutableList.of(commonTypecommonType),
                    ImmutableList.of(
                            invoke(firstCast.getMethodHandle(), ImmutableList.of(first)),
                            invoke(secondCast.getMethodHandle(), ImmutableList.of(second))));
            if (equal) {
                return null;
            }
            else {
                return first;
            }
        }
        /** NOT: NULL stays NULL, an unresolved operand yields a rebuilt node, otherwise the Boolean is negated. */
        @Override
        protected Object visitNotExpression(NotExpression nodeObject context)
        {
            Object value = process(node.getValue(), context);
            if (value == null) {
                return null;
            }
            if (value instanceof Expression) {
                return new NotExpression(toExpression(value.get(node.getValue())));
            }
            return !(Booleanvalue;
        }
        /**
         * AND/OR with three-valued logic. Both sides are always evaluated. Each case returns early when a side
         * decides the result; per the inline comments, a false side decides AND and a true side decides OR. The
         * boolean constants themselves were stripped from this listing. Otherwise two NULLs yield NULL, and any
         * other combination is rebuilt as a logical expression.
         */
        @Override
        protected Object visitLogicalBinaryExpression(LogicalBinaryExpression nodeObject context)
        {
            Object left = process(node.getLeft(), context);
            Object right = process(node.getRight(), context);
            switch (node.getType()) {
                case : {
                    // if either left or right is false, result is always false regardless of nulls
                    if (..equals(left) || ..equals(right)) {
                        return left;
                    }
                    if (..equals(right) || ..equals(left)) {
                        return right;
                    }
                }
                // NOTE(review): there is no break after the block above, so control falls through into the
                // next case's checks. Confirm this is intended.
                case : {
                    // if either left or right is true, result is always true regardless of nulls
                    if (..equals(left) || ..equals(right)) {
                        return left;
                    }
                    if (..equals(right) || ..equals(left)) {
                        return right;
                    }
                }
            }
            if (left == null && right == null) {
                return null;
            }
            return new LogicalBinaryExpression(node.getType(),
                    toExpression(left.get(node.getLeft())),
                    toExpression(right.get(node.getRight())));
        }
        /** Boolean literal: whether the node equals the stripped literal constant (presumably the TRUE literal). */
        @Override
        protected Object visitBooleanLiteral(BooleanLiteral nodeObject context)
        {
            return node.equals(.);
        }
        /**
         * Function call. All arguments are evaluated, then the function is resolved by name and argument type
         * signatures. A NULL argument in a position the function does not declare nullable yields NULL.
         * When the stripped leading condition holds (presumably optimizer mode), a non-deterministic function or
         * any unresolved argument yields a rebuilt call instead of an invocation.
         */
        @Override
        protected Object visitFunctionCall(FunctionCall nodeObject context)
        {
            List<TypeargumentTypes = new ArrayList<>();
            List<ObjectargumentValues = new ArrayList<>();
            for (Expression expression : node.getArguments()) {
                Object value = process(expressioncontext);
                Type type = .get(expression);
                argumentValues.add(value);
                argumentTypes.add(type);
            }
            FunctionInfo function = .resolveFunction(node.getName(), Lists.transform(argumentTypes, Type::getTypeSignature), false);
            for (int i = 0; i < argumentValues.size(); i++) {
                Object value = argumentValues.get(i);
                if (value == null && !function.getNullableArguments().get(i)) {
                    return null;
                }
            }
            // do not optimize non-deterministic functions
            if ( && (!function.isDeterministic() || hasUnresolvedValue(argumentValues))) {
                return new FunctionCall(node.getName(), node.getWindow(), node.isDistinct(), toExpressions(argumentValuesargumentTypes));
            }
            return invoke(function.getMethodHandle(), argumentValues);
        }
        /**
         * LIKE. A NULL value, pattern or escape yields NULL. Otherwise the first applicable path wins:
         * <ol>
         *   <li>A resolved {@code Slice} value with a string-literal pattern (and a string-literal or absent
         *       escape) uses the cached compiled pattern from {@code getConstantPattern}.</li>
         *   <li>If value, pattern and escape all resolve to {@code Slice}s, a pattern is compiled for this
         *       call.</li>
         *   <li>A resolved pattern with no escape and no '%' or '_' is rewritten as a comparison (constant
         *       stripped; presumably EQUAL).</li>
         *   <li>Anything else is rebuilt as a LIKE predicate.</li>
         * </ol>
         */
        @Override
        protected Object visitLikePredicate(LikePredicate nodeObject context)
        {
            Object value = process(node.getValue(), context);
            if (value == null) {
                return null;
            }
            if (value instanceof Slice &&
                    node.getPattern() instanceof StringLiteral &&
                    (node.getEscape() instanceof StringLiteral || node.getEscape() == null)) {
                // fast path when we know the pattern and escape are constant
                return LikeFunctions.like((SlicevaluegetConstantPattern(node));
            }
            Object pattern = process(node.getPattern(), context);
            if (pattern == null) {
                return null;
            }
            Object escape = null;
            if (node.getEscape() != null) {
                escape = process(node.getEscape(), context);
                if (escape == null) {
                    return null;
                }
            }
            if (value instanceof Slice &&
                    pattern instanceof Slice &&
                    (escape == null || escape instanceof Slice)) {
                // pattern is compiled per call here; only the literal fast path above is cached
                Regex regex;
                if (escape == null) {
                    regex = LikeFunctions.likePattern((Slicepattern);
                }
                else {
                    regex = LikeFunctions.likePattern((Slicepattern, (Sliceescape);
                }
                return LikeFunctions.like((Slicevalueregex);
            }
            // if pattern is a constant without % or _ replace with a comparison
            if (pattern instanceof Slice && escape == null) {
                String stringPattern = ((Slicepattern).toString();
                if (!stringPattern.contains("%") && !stringPattern.contains("_")) {
                    return new ComparisonExpression(..,
                            toExpression(value.get(node.getValue())),
                            toExpression(pattern.get(node.getPattern())));
                }
            }
            Expression optimizedEscape = null;
            if (node.getEscape() != null) {
                optimizedEscape = toExpression(escape.get(node.getEscape()));
            }
            return new LikePredicate(
                    toExpression(value.get(node.getValue())),
                    toExpression(pattern.get(node.getPattern())),
                    optimizedEscape);
        }
        /**
         * Returns the compiled LIKE pattern for a predicate whose pattern (and escape, if present) are string
         * literals. The pattern is compiled on first use and stored in a cache keyed by the node (cache field
         * stripped; presumably likePatternCache).
         */
        private Regex getConstantPattern(LikePredicate node)
        {
            Regex result = .get(node);
            if (result == null) {
                StringLiteral pattern = (StringLiteralnode.getPattern();
                StringLiteral escape = (StringLiteralnode.getEscape();
                if (escape == null) {
                    result = LikeFunctions.likePattern(pattern.getSlice());
                }
                else {
                    result = LikeFunctions.likePattern(pattern.getSlice(), escape.getSlice());
                }
                .put(noderesult);
            }
            return result;
        }
        /**
         * CAST / safe cast. Steps, in order:
         * <ol>
         *   <li>An unresolved input is rebuilt as a Cast.</li>
         *   <li>If the stripped condition holds (presumably optimizer mode) and the cast's type is not a
         *       supported literal type, the cast is kept as an expression. This check runs before the NULL
         *       check, so it applies to NULL inputs too.</li>
         *   <li>A NULL input yields NULL.</li>
         *   <li>Otherwise the coercion from the input type to the target type is invoked. If it throws a
         *       RuntimeException, a safe cast returns NULL and a regular cast rethrows.</li>
         * </ol>
         */
        @Override
        public Object visitCast(Cast nodeObject context)
        {
            Object value = process(node.getExpression(), context);
            if (value instanceof Expression) {
                return new Cast((Expressionvaluenode.getType(), node.isSafe());
            }
            // hack!!! don't optimize CASTs for types that cannot be represented in the SQL AST
            // TODO: this will not be an issue when we migrate to RowExpression tree for this, which allows arbitrary literals.
            if ( && !FunctionRegistry.isSupportedLiteralType(.get(node))) {
                return new Cast(toExpression(value.get(node.getExpression())), node.getType(), node.isSafe());
            }
            if (value == null) {
                return null;
            }
            Type type = .getType(parseTypeSignature(node.getType()));
            if (type == null) {
                throw new IllegalArgumentException("Unsupported type: " + node.getType());
            }
            FunctionInfo operatorInfo = .getFunctionRegistry().getCoercion(.get(node.getExpression()), type);
            try {
                return invoke(operatorInfo.getMethodHandle(), ImmutableList.of(value));
            }
            catch (RuntimeException e) {
                if (node.isSafe()) {
                    return null;
                }
                throw e;
            }
        }
        /** Array construction is evaluated as a function call on the node's values (function-name constant stripped). */
        @Override
        protected Object visitArrayConstructor(ArrayConstructor nodeObject context)
        {
            return visitFunctionCall(new FunctionCall(QualifiedName.of(.), node.getValues()), context);
        }
        /** Row constructors are not supported by this evaluator. */
        @Override
        protected Object visitRow(Row nodeObject context)
        {
            throw new UnsupportedOperationException("Row expressions not yet supported");
        }
        /**
         * Subscript: a NULL base or index yields NULL; unresolved operands yield a rebuilt node. Otherwise the
         * stripped operator (presumably SUBSCRIPT) is invoked.
         */
        @Override
        protected Object visitSubscriptExpression(SubscriptExpression nodeObject context)
        {
            Object base = process(node.getBase(), context);
            if (base == null) {
                return null;
            }
            Object index = process(node.getIndex(), context);
            if (index == null) {
                return null;
            }
            if (hasUnresolvedValue(baseindex)) {
                return new SubscriptExpression(toExpression(base.get(node.getBase())), toExpression(index.get(node.getIndex())));
            }
            return invokeOperator(.types(node.getBase(), node.getIndex()), ImmutableList.of(baseindex));
        }
        /**
         * Fallback for expression types without a dedicated visit method.
         * NOTE(review): the exception's leading argument appears stripped; presumably the NOT_SUPPORTED
         * error code statically imported by this file.
         */
        @Override
        protected Object visitExpression(Expression nodeObject context)
        {
            throw new PrestoException("not yet implemented: " + node.getClass().getName());
        }
        /** Non-expression nodes cannot be evaluated. */
        @Override
        protected Object visitNode(Node nodeObject context)
        {
            throw new UnsupportedOperationException("Evaluator visitor can only handle Expression nodes");
        }
        /**
         * Looks up the type of each expression through {@code Functions.forMap} (map argument stripped;
         * presumably the expression-type map). Guava's forMap function throws IllegalArgumentException for a
         * key with no entry.
         */
        private List<Typetypes(Expression... types)
        {
            return ImmutableList.copyOf(Iterables.transform(ImmutableList.copyOf(types), Functions.forMap()));
        }
        /** Varargs form of {@code hasUnresolvedValue(List)}. */
        private boolean hasUnresolvedValue(Object... values)
        {
            return hasUnresolvedValue(ImmutableList.copyOf(values));
        }
        /** Returns true if any value is an {@code Expression}, i.e. it was not resolved to a constant. */
        private boolean hasUnresolvedValue(List<Objectvalues)
        {
            return any(valuesinstanceOf(Expression.class));
        }
        /** Resolves the operator for the given argument types (receiver stripped) and invokes it with the values. */
        private Object invokeOperator(OperatorType operatorTypeList<? extends TypeargumentTypesList<ObjectargumentValues)
        {
            FunctionInfo operatorInfo = .resolveOperator(operatorTypeargumentTypes);
            return invoke(operatorInfo.getMethodHandle(), argumentValues);
        }
    }
    /**
     * Evaluation context for block-based input: a row position plus blocks indexed by input channel.
     * The blocks array is stored as passed in, not copied.
     */
    // NOTE(review): the assigned/returned field names were stripped from this listing; presumably position and blocks.
    private static class PagePositionContext
    {
        private final int position;
        private final Block[] blocks;
        private PagePositionContext(int positionBlock[] blocks)
        {
            this. = position;
            this. = blocks;
        }
        /** Returns the block for {@code channel}; no bounds check beyond normal array indexing. */
        public Block getBlock(int channel)
        {
            return [channel];
        }
        /** Returns the row position to read within each block. */
        public int getPosition()
        {
            return ;
        }
    }
    /**
     * Invokes {@code handle} with {@code argumentValues}. When the handle's first parameter type is
     * {@code ConnectorSession}, {@code session} is bound to it first.
     *
     * <p>Any throwable is rethrown through Guava {@code Throwables.propagate}: RuntimeExceptions and Errors
     * as-is, others wrapped in a RuntimeException. For an {@code InterruptedException}, the current
     * thread's interrupt flag is restored first.
     *
     * @param session session bound as the first argument when the handle expects one
     * @param handle the method handle to invoke
     * @param argumentValues the remaining arguments, in order
     * @return the handle's result
     */
    // NOTE(review): commas in the signature were stripped from this listing. Throwables.propagate is
    // deprecated in newer Guava releases.
    public static Object invoke(ConnectorSession sessionMethodHandle handleList<ObjectargumentValues)
    {
        if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) {
            handle = handle.bindTo(session);
        }
        try {
            return handle.invokeWithArguments(argumentValues);
        }
        catch (Throwable throwable) {
            if (throwable instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw Throwables.propagate(throwable);
        }
    }
    private static boolean isNullLiteral(Expression entry)
    {
        return entry instanceof Literal && !(entry instanceof NullLiteral);
    }