Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.planner.optimizations;
 
 
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Predicates.in;
 import static com.google.common.collect.Iterables.concat;

/**
 * Removes all computation that is not referenced transitively from the root of the plan.
 * <p>
 * E.g.,
 * <p>
 * {@code Output[$0] -> Project[$0 := $1 + $2, $3 := $4 / $5] -> ...}
 * <p>
 * gets rewritten as
 * <p>
 * {@code Output[$0] -> Project[$0 := $1 + $2] -> ...}
 */

 
         extends PlanOptimizer
 {
     @Override
     public PlanNode optimize(PlanNode planSession sessionMap<SymbolTypetypesSymbolAllocator symbolAllocatorPlanNodeIdAllocator idAllocator)
     {
         checkNotNull(plan"plan is null");
         checkNotNull(session"session is null");
         checkNotNull(types"types is null");
         checkNotNull(symbolAllocator"symbolAllocator is null");
         checkNotNull(idAllocator"idAllocator is null");
 
         return PlanRewriter.rewriteWith(new Rewriter(types), plan, ImmutableSet.<Symbol>of());
     }
 
     private static class Rewriter
             extends PlanRewriter<Set<Symbol>>
     {
         private final Map<SymbolTypetypes;
 
         public Rewriter(Map<SymbolTypetypes)
         {
            this. = types;
        }
        @Override
        public PlanNode visitJoin(JoinNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolleftInputsBuilder = ImmutableSet.builder();
            leftInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft));
            if (node.getLeftHashSymbol().isPresent()) {
                leftInputsBuilder.add(node.getLeftHashSymbol().get());
            }
            Set<SymbolleftInputs = leftInputsBuilder.build();
            ImmutableSet.Builder<SymbolrightInputsBuilder = ImmutableSet.builder();
            rightInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight));
            if (node.getRightHashSymbol().isPresent()) {
                rightInputsBuilder.add(node.getRightHashSymbol().get());
            }
            Set<SymbolrightInputs = rightInputsBuilder.build();
            PlanNode left = context.rewrite(node.getLeft(), leftInputs);
            PlanNode right = context.rewrite(node.getRight(), rightInputs);
            return new JoinNode(node.getId(), node.getType(), leftrightnode.getCriteria(), node.getLeftHashSymbol(), node.getRightHashSymbol());
        }
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolsourceInputsBuilder = ImmutableSet.builder();
            sourceInputsBuilder.addAll(context.get()).add(node.getSourceJoinSymbol());
            if (node.getSourceHashSymbol().isPresent()) {
                sourceInputsBuilder.add(node.getSourceHashSymbol().get());
            }
            Set<SymbolsourceInputs = sourceInputsBuilder.build();
            ImmutableSet.Builder<SymbolfilteringSourceInputBuilder = ImmutableSet.builder();
            filteringSourceInputBuilder.add(node.getFilteringSourceJoinSymbol());
            if (node.getFilteringSourceHashSymbol().isPresent()) {
                filteringSourceInputBuilder.add(node.getFilteringSourceHashSymbol().get());
            }
            Set<SymbolfilteringSourceInputs = filteringSourceInputBuilder.build();
            PlanNode source = context.rewrite(node.getSource(), sourceInputs);
            PlanNode filteringSource = context.rewrite(node.getFilteringSource(), filteringSourceInputs);
            return new SemiJoinNode(node.getId(),
                    source,
                    filteringSource,
                    node.getSourceJoinSymbol(),
                    node.getFilteringSourceJoinSymbol(),
                    node.getSemiJoinOutput(),
                    node.getSourceHashSymbol(),
                    node.getFilteringSourceHashSymbol());
        }
        @Override
        public PlanNode visitIndexJoin(IndexJoinNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolprobeInputsBuilder = ImmutableSet.builder();
            probeInputsBuilder.addAll(context.get())
                    .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getProbe));
            if (node.getProbeHashSymbol().isPresent()) {
                probeInputsBuilder.add(node.getProbeHashSymbol().get());
            }
            Set<SymbolprobeInputs = probeInputsBuilder.build();
            ImmutableSet.Builder<SymbolindexInputBuilder = ImmutableSet.builder();
            indexInputBuilder.addAll(context.get())
                    .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getIndex));
            if (node.getIndexHashSymbol().isPresent()) {
                indexInputBuilder.add(node.getIndexHashSymbol().get());
            }
            Set<SymbolindexInputs = indexInputBuilder.build();
            PlanNode probeSource = context.rewrite(node.getProbeSource(), probeInputs);
            PlanNode indexSource = context.rewrite(node.getIndexSource(), indexInputs);
            return new IndexJoinNode(node.getId(), node.getType(), probeSourceindexSourcenode.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol());
        }
        @Override
        public PlanNode visitIndexSource(IndexSourceNode nodeRewriteContext<Set<Symbol>> context)
        {
            List<SymbolnewOutputSymbols = FluentIterable.from(node.getOutputSymbols())
                    .filter(in(context.get()))
                    .toList();
            Set<SymbolnewLookupSymbols = FluentIterable.from(node.getLookupSymbols())
                    .filter(in(context.get()))
                    .toSet();
            Set<SymbolrequiredAssignmentSymbols = context.get();
            if (!node.getEffectiveTupleDomain().isNone()) {
                Set<SymbolrequiredSymbols = Maps.filterValues(node.getAssignments(), in(node.getEffectiveTupleDomain().getDomains().keySet())).keySet();
                requiredAssignmentSymbols = Sets.union(context.get(), requiredSymbols);
            }
            Map<SymbolColumnHandlenewAssignments = Maps.filterKeys(node.getAssignments(), in(requiredAssignmentSymbols));
            return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), newLookupSymbolsnewOutputSymbolsnewAssignmentsnode.getEffectiveTupleDomain());
        }
        @Override
        public PlanNode visitAggregation(AggregationNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(node.getGroupBy());
            if (node.getHashSymbol().isPresent()) {
                expectedInputs.add(node.getHashSymbol().get());
            }
            ImmutableMap.Builder<SymbolSignaturefunctions = ImmutableMap.builder();
            ImmutableMap.Builder<SymbolFunctionCallfunctionCalls = ImmutableMap.builder();
            ImmutableMap.Builder<SymbolSymbolmasks = ImmutableMap.builder();
            for (Map.Entry<SymbolFunctionCallentry : node.getAggregations().entrySet()) {
                Symbol symbol = entry.getKey();
                if (context.get().contains(symbol)) {
                    FunctionCall call = entry.getValue();
                    expectedInputs.addAll(DependencyExtractor.extractUnique(call));
                    if (node.getMasks().containsKey(symbol)) {
                        expectedInputs.add(node.getMasks().get(symbol));
                        masks.put(symbolnode.getMasks().get(symbol));
                    }
                    functionCalls.put(symbolcall);
                    functions.put(symbolnode.getFunctions().get(symbol));
                }
            }
            if (node.getSampleWeight().isPresent()) {
                expectedInputs.add(node.getSampleWeight().get());
            }
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new AggregationNode(node.getId(),
                    source,
                    node.getGroupBy(),
                    functionCalls.build(),
                    functions.build(),
                    masks.build(),
                    node.getSampleWeight(),
                    node.getConfidence(),
                    node.getHashSymbol());
        }
        @Override
        public PlanNode visitWindow(WindowNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .addAll(node.getPartitionBy())
                    .addAll(node.getOrderBy());
            if (node.getFrame().getStartValue().isPresent()) {
                expectedInputs.add(node.getFrame().getStartValue().get());
            }
            if (node.getFrame().getEndValue().isPresent()) {
                expectedInputs.add(node.getFrame().getEndValue().get());
            }
            if (node.getHashSymbol().isPresent()) {
                expectedInputs.add(node.getHashSymbol().get());
            }
            ImmutableMap.Builder<SymbolSignaturefunctions = ImmutableMap.builder();
            ImmutableMap.Builder<SymbolFunctionCallfunctionCalls = ImmutableMap.builder();
            for (Map.Entry<SymbolFunctionCallentry : node.getWindowFunctions().entrySet()) {
                Symbol symbol = entry.getKey();
                if (context.get().contains(symbol)) {
                    FunctionCall call = entry.getValue();
                    expectedInputs.addAll(DependencyExtractor.extractUnique(call));
                    functionCalls.put(symbolcall);
                    functions.put(symbolnode.getSignatures().get(symbol));
                }
            }
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new WindowNode(node.getId(), sourcenode.getPartitionBy(), node.getOrderBy(), node.getOrderings(), node.getFrame(), functionCalls.build(), functions.build(), node.getHashSymbol());
        }
        @Override
        public PlanNode visitTableScan(TableScanNode nodeRewriteContext<Set<Symbol>> context)
        {
            Set<SymbolrequiredTableScanOutputs = FluentIterable.from(context.get())
                    .filter(in(ImmutableSet.copyOf(node.getOutputSymbols())))
                    .toSet();
            List<SymbolnewOutputSymbols = FluentIterable.from(node.getOutputSymbols())
                    .filter(in(requiredTableScanOutputs))
                    .toList();
            Set<SymbolrequiredAssignmentSymbols = requiredTableScanOutputs;
            if (!node.getPartitionsDomainSummary().isNone()) {
                Set<SymbolrequiredPartitionDomainSymbols = Maps.filterValues(node.getAssignments(), in(node.getPartitionsDomainSummary().getDomains().keySet())).keySet();
                requiredAssignmentSymbols = Sets.union(requiredTableScanOutputsrequiredPartitionDomainSymbols);
            }
            Map<SymbolColumnHandlenewAssignments = Maps.filterKeys(node.getAssignments(), in(requiredAssignmentSymbols));
            return new TableScanNode(node.getId(), node.getTable(), newOutputSymbolsnewAssignmentsnode.getOriginalConstraint(), node.getSummarizedPartition());
        }
        @Override
        public PlanNode visitFilter(FilterNode nodeRewriteContext<Set<Symbol>> context)
        {
            Set<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(DependencyExtractor.extractUnique(node.getPredicate()))
                    .addAll(context.get())
                    .build();
            PlanNode source = context.rewrite(node.getSource(), expectedInputs);
            return new FilterNode(node.getId(), sourcenode.getPredicate());
        }
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode nodeRewriteContext<Set<Symbol>> context)
        {
            if (!context.get().contains(node.getMarkerSymbol())) {
                return context.rewrite(node.getSource(), context.get());
            }
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(node.getDistinctSymbols())
                    .addAll(context.get());
            if (node.getHashSymbol().isPresent()) {
                expectedInputs.add(node.getHashSymbol().get());
            }
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new MarkDistinctNode(node.getId(), sourcenode.getMarkerSymbol(), node.getDistinctSymbols(), node.getHashSymbol());
        }
        @Override
        public PlanNode visitUnnest(UnnestNode nodeRewriteContext<Set<Symbol>> context)
        {
            List<SymbolreplicateSymbols = FluentIterable.from(node.getReplicateSymbols())
                    .filter(in(context.get()))
                    .toList();
            Map<SymbolList<Symbol>> unnestSymbols = node.getUnnestSymbols();
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(replicateSymbols)
                    .addAll(unnestSymbols.keySet());
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new UnnestNode(node.getId(), sourcereplicateSymbolsunnestSymbols);
        }
        @Override
        public PlanNode visitProject(ProjectNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.builder();
            ImmutableMap.Builder<SymbolExpressionbuilder = ImmutableMap.builder();
            for (int i = 0; i < node.getOutputSymbols().size(); i++) {
                Symbol output = node.getOutputSymbols().get(i);
                Expression expression = node.getExpressions().get(i);
                if (context.get().contains(output)) {
                    expectedInputs.addAll(DependencyExtractor.extractUnique(expression));
                    builder.put(outputexpression);
                }
            }
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new ProjectNode(node.getId(), sourcebuilder.build());
        }
        @Override
        public PlanNode visitOutput(OutputNode nodeRewriteContext<Set<Symbol>> context)
        {
            Set<SymbolexpectedInputs = ImmutableSet.copyOf(node.getOutputSymbols());
            PlanNode source = context.rewrite(node.getSource(), expectedInputs);
            return new OutputNode(node.getId(), sourcenode.getColumnNames(), node.getOutputSymbols());
        }
        @Override
        public PlanNode visitLimit(LimitNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get());
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new LimitNode(node.getId(), sourcenode.getCount());
        }
        @Override
        public PlanNode visitDistinctLimit(DistinctLimitNode nodeRewriteContext<Set<Symbol>> context)
        {
            Set<SymbolexpectedInputs;
            if (node.getHashSymbol().isPresent()) {
                expectedInputs = ImmutableSet.copyOf(concat(node.getOutputSymbols(), ImmutableList.of(node.getHashSymbol().get())));
            }
            else {
                expectedInputs = ImmutableSet.copyOf(node.getOutputSymbols());
            }
            PlanNode source = context.rewrite(node.getSource(), expectedInputs);
            return new DistinctLimitNode(node.getId(), sourcenode.getLimit(), node.getHashSymbol());
        }
        @Override
        public PlanNode visitTopN(TopNNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .addAll(node.getOrderBy());
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new TopNNode(node.getId(), sourcenode.getCount(), node.getOrderBy(), node.getOrderings(), node.isPartial());
        }
        @Override
        public PlanNode visitRowNumber(RowNumberNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolinputsBuilder = ImmutableSet.builder();
            ImmutableSet.Builder<SymbolexpectedInputs = inputsBuilder
                    .addAll(context.get())
                    .addAll(node.getPartitionBy());
            if (node.getHashSymbol().isPresent()) {
                inputsBuilder.add(node.getHashSymbol().get());
            }
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new RowNumberNode(node.getId(), sourcenode.getPartitionBy(), node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.getHashSymbol());
        }
        @Override
        public PlanNode visitTopNRowNumber(TopNRowNumberNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .addAll(node.getPartitionBy())
                    .addAll(node.getOrderBy());
            if (node.getHashSymbol().isPresent()) {
                expectedInputs.add(node.getHashSymbol().get());
            }
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new TopNRowNumberNode(node.getId(),
                    source,
                    node.getPartitionBy(),
                    node.getOrderBy(),
                    node.getOrderings(),
                    node.getRowNumberSymbol(),
                    node.getMaxRowCountPerPartition(),
                    node.isPartial(),
                    node.getHashSymbol());
        }
        @Override
        public PlanNode visitSort(SortNode nodeRewriteContext<Set<Symbol>> context)
        {
            Set<SymbolexpectedInputs = ImmutableSet.copyOf(concat(context.get(), node.getOrderBy()));
            PlanNode source = context.rewrite(node.getSource(), expectedInputs);
            return new SortNode(node.getId(), sourcenode.getOrderBy(), node.getOrderings());
        }
        @Override
        public PlanNode visitTableWriter(TableWriterNode nodeRewriteContext<Set<Symbol>> context)
        {
            ImmutableSet.Builder<SymbolexpectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(node.getColumns());
            if (node.getSampleWeightSymbol().isPresent()) {
                expectedInputs.add(node.getSampleWeightSymbol().get());
            }
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new TableWriterNode(node.getId(), sourcenode.getTarget(), node.getColumns(), node.getColumnNames(), node.getOutputSymbols(), node.getSampleWeightSymbol());
        }
        @Override
        public PlanNode visitUnion(UnionNode nodeRewriteContext<Set<Symbol>> context)
        {
            // Find out which output symbols we need to keep
            ImmutableListMultimap.Builder<SymbolSymbolrewrittenSymbolMappingBuilder = ImmutableListMultimap.builder();
            for (Symbol symbol : node.getOutputSymbols()) {
                if (context.get().contains(symbol)) {
                    rewrittenSymbolMappingBuilder.putAll(symbolnode.getSymbolMapping().get(symbol));
                }
            }
            ListMultimap<SymbolSymbolrewrittenSymbolMapping = rewrittenSymbolMappingBuilder.build();
            // Find the corresponding input symbol to the remaining output symbols and prune the subplans
            ImmutableList.Builder<PlanNoderewrittenSubPlans = ImmutableList.builder();
            for (int i = 0; i < node.getSources().size(); i++) {
                ImmutableSet.Builder<SymbolexpectedInputSymbols = ImmutableSet.builder();
                for (Collection<Symbolsymbols : rewrittenSymbolMapping.asMap().values()) {
                    expectedInputSymbols.add(Iterables.get(symbolsi));
                }
                rewrittenSubPlans.add(context.rewrite(node.getSources().get(i), expectedInputSymbols.build()));
            }
            return new UnionNode(node.getId(), rewrittenSubPlans.build(), rewrittenSymbolMapping);
        }
    }
New to GrepCode? Check out our FAQ X