Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.planner.optimizations;
 
 
 import java.util.Map;
 import java.util.Set;
 
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.collect.Iterables.getOnlyElement;
 
 public class WindowFilterPushDown
         extends PlanOptimizer
 {
     // Signature constant built with the function name "row_number" and an empty String list.
     // NOTE(review): the text between "row_number" and ImmutableList is truncated in this copy
     // (only "." remains). Presumably a qualified constant (e.g. the return type) belongs there —
     // restore it from the upstream source before compiling.
     private static final Signature ROW_NUMBER_SIGNATURE = new Signature("row_number"., ImmutableList.<String>of());
 
     @Override
     public PlanNode optimize(PlanNode planSession sessionMap<SymbolTypetypesSymbolAllocator symbolAllocatorPlanNodeIdAllocator idAllocator)
     {
         checkNotNull(plan"plan is null");
         checkNotNull(session"session is null");
         checkNotNull(types"types is null");
         checkNotNull(symbolAllocator"symbolAllocator is null");
         checkNotNull(idAllocator"idAllocator is null");
 
         return PlanRewriter.rewriteWith(new Rewriter(idAllocator), plannull);
     }
 
     private static class Rewriter
             extends PlanRewriter<Constraint>
     {
         private final PlanNodeIdAllocator idAllocator;
 
         private Rewriter(PlanNodeIdAllocator idAllocator)
         {
             this. = checkNotNull(idAllocator"idAllocator is null");
         }
 
         @Override
         public PlanNode visitWindow(WindowNode nodeRewriteContext<Constraintcontext)
         {
             if (canOptimizeWindowFunction(node)) {
                 PlanNode rewrittenSource = context.rewrite(node.getSource(), null);
                 Optional<Integerlimit = getLimit(nodecontext.get());
                 if (node.getOrderBy().isEmpty()) {
                     return new RowNumberNode(.getNextId(),
                             rewrittenSource,
                             node.getPartitionBy(),
                             getOnlyElement(node.getWindowFunctions().keySet()),
                             limit,
                             Optional.empty());
                 }
                 if (limit.isPresent()) {
                     return new TopNRowNumberNode(.getNextId(),
                             rewrittenSource,
                             node.getPartitionBy(),
                             node.getOrderBy(),
                             node.getOrderings(),
                             getOnlyElement(node.getWindowFunctions().keySet()),
                             limit.get(),
                             false,
                             Optional.empty());
                }
            }
            return context.defaultRewrite(node);
        }
        private static Optional<IntegergetLimit(WindowNode nodeConstraint filter)
        {
            if (filter == null || (!filter.getLimit().isPresent() && !filter.getFilterExpression().isPresent())) {
                return Optional.empty();
            }
            if (filter.getLimit().isPresent()) {
                return filter.getLimit();
            }
            if (filterContainsWindowFunctions(nodefilter.getFilterExpression().get()) &&
                    filter.getFilterExpression().get() instanceof ComparisonExpression) {
                Symbol rowNumberSymbol = Iterables.getOnlyElement(node.getWindowFunctions().entrySet()).getKey();
                return WindowLimitExtractor.extract(filter.getFilterExpression().get(), rowNumberSymbol);
            }
            return Optional.empty();
        }
        private static boolean canOptimizeWindowFunction(WindowNode node)
        {
            if (node.getWindowFunctions().size() != 1) {
                return false;
            }
            Symbol rowNumberSymbol = getOnlyElement(node.getWindowFunctions().entrySet()).getKey();
            return isRowNumberSignature(node.getSignatures().get(rowNumberSymbol));
        }
        private static boolean filterContainsWindowFunctions(WindowNode nodeExpression filterPredicate)
        {
            Set<SymbolwindowFunctionSymbols = node.getWindowFunctions().keySet();
            Sets.SetView<SymbolcommonSymbols = Sets.intersection(DependencyExtractor.extractUnique(filterPredicate), windowFunctionSymbols);
            return !commonSymbols.isEmpty();
        }
        @Override
        public PlanNode visitLimit(LimitNode nodeRewriteContext<Constraintcontext)
        {
            // Operators can handle MAX_VALUE rows per page, so do not optimize if count is greater than this value
            if (node.getCount() >= .) {
                return context.defaultRewrite(node);
            }
            Constraint constraint = new Constraint(Optional.of((intnode.getCount()), Optional.empty());
            PlanNode rewrittenSource = context.rewrite(node.getSource(), constraint);
            if (rewrittenSource != node.getSource()) {
                return rewrittenSource;
            }
            return context.defaultRewrite(node);
        }
        @Override
        public PlanNode visitFilter(FilterNode nodeRewriteContext<Constraintcontext)
        {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), new Constraint(Optional.empty(), Optional.of(node.getPredicate())));
            if (rewrittenSource != node.getSource()) {
                if (rewrittenSource instanceof TopNRowNumberNode) {
                    return rewrittenSource;
                }
                return new FilterNode(.getNextId(), rewrittenSourcenode.getPredicate());
            }
            return context.defaultRewrite(node);
        }
    }
    private static boolean isRowNumberSignature(Signature signature)
    {
        return signature.equals();
    }
    private static class Constraint
    {
        private final Optional<Integerlimit;
        private final Optional<ExpressionfilterExpression;
        private Constraint(Optional<IntegerlimitOptional<ExpressionfilterExpression)
        {
            this. = limit;
            this. = filterExpression;
        }
        public Optional<IntegergetLimit()
        {
            return ;
        }
        public Optional<ExpressiongetFilterExpression()
        {
            return ;
        }
    }
    private static final class WindowLimitExtractor
    {
        private WindowLimitExtractor() {}
        public static Optional<Integerextract(Expression expressionSymbol rowNumberSymbol)
        {
            Visitor visitor = new Visitor();
            Long limit = visitor.process(expressionrowNumberSymbol);
            if (limit == null || limit >= .) {
                return Optional.empty();
            }
            return Optional.of(limit.intValue());
        }
        private static class Visitor
                extends DefaultExpressionTraversalVisitor<LongSymbol>
        {
            @Override
            protected Long visitComparisonExpression(ComparisonExpression nodeSymbol rowNumberSymbol)
            {
                Optional<QualifiedNameReferencereference = extractReference(node);
                Optional<Literalliteral = extractLiteral(node);
                if (!reference.isPresent() || !literal.isPresent()) {
                    return null;
                }
                if (!Symbol.fromQualifiedName(reference.get().getName()).equals(rowNumberSymbol)) {
                    return null;
                }
                long literalValue = extractValue(literal.get());
                if (node.getLeft() instanceof QualifiedNameReference && node.getRight() instanceof Literal) {
                    if (node.getType() == ..) {
                        return literalValue;
                    }
                    if (node.getType() == ..) {
                        return literalValue - 1;
                    }
                }
                else if (node.getLeft() instanceof Literal && node.getRight() instanceof QualifiedNameReference) {
                    if (node.getType() == ..) {
                        return literalValue;
                    }
                    if (node.getType() == ..) {
                        return literalValue - 1;
                    }
                }
                return null;
            }
        }
        private static Optional<QualifiedNameReferenceextractReference(ComparisonExpression expression)
        {
            if (expression.getLeft() instanceof QualifiedNameReference) {
                return Optional.of((QualifiedNameReferenceexpression.getLeft());
            }
            if (expression.getRight() instanceof QualifiedNameReference) {
                return Optional.of((QualifiedNameReferenceexpression.getRight());
            }
            return Optional.empty();
        }
        private static Optional<LiteralextractLiteral(ComparisonExpression expression)
        {
            if (expression.getLeft() instanceof Literal) {
                return Optional.of((Literalexpression.getLeft());
            }
            if (expression.getRight() instanceof Literal) {
                return Optional.of((Literalexpression.getRight());
            }
            return Optional.empty();
        }
        private static long extractValue(Literal literal)
        {
            if (literal instanceof DoubleLiteral) {
                return (long) ((DoubleLiteralliteral).getValue();
            }
            if (literal instanceof LongLiteral) {
                return ((LongLiteralliteral).getValue();
            }
            throw new IllegalArgumentException("Row number compared to non numeric literal");
        }
    }
New to GrepCode? Check out our FAQ X