Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.planner;
 
 
 import java.util.Map;
 import java.util.Set;
 
 import static com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer;
 import static com.google.common.base.Preconditions.checkArgument;

Ensures that all dependencies (i.e., symbols in expressions) for a plan node are provided by its source nodes
 
 public final class PlanSanityChecker
 {
     private PlanSanityChecker() {}
 
     public static void validate(PlanNode plan)
     {
         plan.accept(new Visitor(), null);
     }
 
     private static class Visitor
             extends PlanVisitor<VoidVoid>
     {
         private final Map<PlanNodeIdPlanNodenodesById = new HashMap<>();
 
         @Override
         protected Void visitPlan(PlanNode nodeVoid context)
         {
             throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName());
         }
 
         @Override
         public Void visitAggregation(AggregationNode nodeVoid context)
         {
             PlanNode source = node.getSource();
             source.accept(thiscontext); // visit child
 
             verifyUniqueId(node);
 
             Set<Symbolinputs = ImmutableSet.copyOf(source.getOutputSymbols());
             checkDependencies(inputsnode.getGroupBy(), "Invalid node. Group by symbols (%s) not in source plan output (%s)"node.getGroupBy(), node.getSource().getOutputSymbols());
 
             if (node.getSampleWeight().isPresent()) {
                 checkArgument(inputs.contains(node.getSampleWeight().get()), "Invalid node. Sample weight symbol (%s) is not in source plan output (%s)"node.getSampleWeight().get(), node.getSource().getOutputSymbols());
             }
 
             for (FunctionCall call : node.getAggregations().values()) {
                 Set<Symboldependencies = DependencyExtractor.extractUnique(call);
                 checkDependencies(inputsdependencies"Invalid node. Aggregation dependencies (%s) not in source plan output (%s)"dependenciesnode.getSource().getOutputSymbols());
             }
            return null;
        }
        @Override
        public Void visitMarkDistinct(MarkDistinctNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            checkDependencies(source.getOutputSymbols(), node.getDistinctSymbols(), "Invalid node. Mark distinct symbols (%s) not in source plan output (%s)"node.getDistinctSymbols(), source.getOutputSymbols());
            return null;
        }
        @Override
        public Void visitWindow(WindowNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            Set<Symbolinputs = ImmutableSet.copyOf(source.getOutputSymbols());
            checkDependencies(inputsnode.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)"node.getPartitionBy(), node.getSource().getOutputSymbols());
            checkDependencies(inputsnode.getOrderBy(), "Invalid node. Order by symbols (%s) not in source plan output (%s)"node.getOrderBy(), node.getSource().getOutputSymbols());
            ImmutableList.Builder<Symbolbounds = ImmutableList.builder();
            if (node.getFrame().getStartValue().isPresent()) {
                bounds.add(node.getFrame().getStartValue().get());
            }
            if (node.getFrame().getEndValue().isPresent()) {
                bounds.add(node.getFrame().getEndValue().get());
            }
            checkDependencies(inputsbounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)"bounds.build(), node.getSource().getOutputSymbols());
            for (FunctionCall call : node.getWindowFunctions().values()) {
                Set<Symboldependencies = DependencyExtractor.extractUnique(call);
                checkDependencies(inputsdependencies"Invalid node. Window function dependencies (%s) not in source plan output (%s)"dependenciesnode.getSource().getOutputSymbols());
            }
            return null;
        }
        @Override
        public Void visitTopNRowNumber(TopNRowNumberNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            Set<Symbolinputs = ImmutableSet.copyOf(source.getOutputSymbols());
            checkDependencies(inputsnode.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)"node.getPartitionBy(), node.getSource().getOutputSymbols());
            checkDependencies(inputsnode.getOrderBy(), "Invalid node. Order by symbols (%s) not in source plan output (%s)"node.getOrderBy(), node.getSource().getOutputSymbols());
            return null;
        }
        @Override
        public Void visitRowNumber(RowNumberNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            checkDependencies(source.getOutputSymbols(), node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)"node.getPartitionBy(), node.getSource().getOutputSymbols());
            return null;
        }
        @Override
        public Void visitFilter(FilterNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            Set<Symbolinputs = ImmutableSet.copyOf(source.getOutputSymbols());
            checkDependencies(inputsnode.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)"node.getOutputSymbols(), node.getSource().getOutputSymbols());
            Set<Symboldependencies = DependencyExtractor.extractUnique(node.getPredicate());
            checkDependencies(inputsdependencies"Invalid node. Predicate dependencies (%s) not in source plan output (%s)"dependenciesnode.getSource().getOutputSymbols());
            return null;
        }
        @Override
        public Void visitSample(SampleNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            return null;
        }
        @Override
        public Void visitProject(ProjectNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            Set<Symbolinputs = ImmutableSet.copyOf(source.getOutputSymbols());
            for (Expression expression : node.getExpressions()) {
                Set<Symboldependencies = DependencyExtractor.extractUnique(expression);
                checkDependencies(inputsdependencies"Invalid node. Expression dependencies (%s) not in source plan output (%s)"dependenciesinputs);
            }
            return null;
        }
        @Override
        public Void visitTopN(TopNNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            Set<Symbolinputs = ImmutableSet.copyOf(source.getOutputSymbols());
            checkDependencies(inputsnode.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)"node.getOutputSymbols(), node.getSource().getOutputSymbols());
            checkDependencies(inputsnode.getOrderBy(),
                    "Invalid node. Order by dependencies (%s) not in source plan output (%s)",
                    node.getOrderBy(),
                    node.getSource().getOutputSymbols());
            return null;
        }
        @Override
        public Void visitSort(SortNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            Set<Symbolinputs = ImmutableSet.copyOf(source.getOutputSymbols());
            checkDependencies(inputsnode.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)"node.getOutputSymbols(), node.getSource().getOutputSymbols());
            checkDependencies(inputsnode.getOrderBy(), "Invalid node. Order by dependencies (%s) not in source plan output (%s)"node.getOrderBy(), node.getSource().getOutputSymbols());
            return null;
        }
        @Override
        public Void visitOutput(OutputNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            checkDependencies(source.getOutputSymbols(), node.getOutputSymbols(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)"node.getOutputSymbols(), source.getOutputSymbols());
            return null;
        }
        @Override
        public Void visitLimit(LimitNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            return null;
        }
        @Override
        public Void visitDistinctLimit(DistinctLimitNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            return null;
        }
        @Override
        public Void visitJoin(JoinNode nodeVoid context)
        {
            node.getLeft().accept(thiscontext);
            node.getRight().accept(thiscontext);
            verifyUniqueId(node);
            Set<SymbolleftInputs = ImmutableSet.copyOf(node.getLeft().getOutputSymbols());
            Set<SymbolrightInputs = ImmutableSet.copyOf(node.getRight().getOutputSymbols());
            for (JoinNode.EquiJoinClause clause : node.getCriteria()) {
                checkArgument(leftInputs.contains(clause.getLeft()), "Symbol from join clause (%s) not in left source (%s)"clause.getLeft(), node.getLeft().getOutputSymbols());
                checkArgument(rightInputs.contains(clause.getRight()), "Symbol from join clause (%s) not in right source (%s)"clause.getRight(), node.getRight().getOutputSymbols());
            }
            return null;
        }
        @Override
        public Void visitSemiJoin(SemiJoinNode nodeVoid context)
        {
            node.getSource().accept(thiscontext);
            node.getFilteringSource().accept(thiscontext);
            verifyUniqueId(node);
            checkArgument(node.getSource().getOutputSymbols().contains(node.getSourceJoinSymbol()), "Symbol from semi join clause (%s) not in source (%s)"node.getSourceJoinSymbol(), node.getSource().getOutputSymbols());
            checkArgument(node.getFilteringSource().getOutputSymbols().contains(node.getFilteringSourceJoinSymbol()), "Symbol from semi join clause (%s) not in filtering source (%s)"node.getSourceJoinSymbol(), node.getFilteringSource().getOutputSymbols());
            Set<Symboloutputs = ImmutableSet.copyOf(node.getOutputSymbols());
            checkArgument(outputs.containsAll(node.getSource().getOutputSymbols()), "Semi join output symbols (%s) must contain all of the source symbols (%s)"node.getOutputSymbols(), node.getSource().getOutputSymbols());
            checkArgument(outputs.contains(node.getSemiJoinOutput()),
                    "Semi join output symbols (%s) must contain join result (%s)",
                    node.getOutputSymbols(),
                    node.getSemiJoinOutput());
            return null;
        }
        @Override
        public Void visitIndexJoin(IndexJoinNode nodeVoid context)
        {
            node.getProbeSource().accept(thiscontext);
            node.getIndexSource().accept(thiscontext);
            verifyUniqueId(node);
            Set<SymbolprobeInputs = ImmutableSet.copyOf(node.getProbeSource().getOutputSymbols());
            Set<SymbolindexSourceInputs = ImmutableSet.copyOf(node.getIndexSource().getOutputSymbols());
            for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) {
                checkArgument(probeInputs.contains(clause.getProbe()), "Probe symbol from index join clause (%s) not in probe source (%s)"clause.getProbe(), node.getProbeSource().getOutputSymbols());
                checkArgument(indexSourceInputs.contains(clause.getIndex()), "Index symbol from index join clause (%s) not in index source (%s)"clause.getIndex(), node.getIndexSource().getOutputSymbols());
            }
            Set<SymbollookupSymbols = FluentIterable.from(node.getCriteria())
                    .transform(IndexJoinNode.EquiJoinClause::getIndex)
                    .toSet();
            Map<SymbolSymboltrace = IndexKeyTracer.trace(node.getIndexSource(), lookupSymbols);
            checkArgument(!trace.isEmpty() && lookupSymbols.containsAll(trace.keySet()),
                    "Index lookup symbols are not traceable to index source: %s",
                    lookupSymbols);
            return null;
        }
        @Override
        public Void visitIndexSource(IndexSourceNode nodeVoid context)
        {
            verifyUniqueId(node);
            checkDependencies(node.getOutputSymbols(), node.getLookupSymbols(), "Lookup symbols must be part of output symbols");
            checkDependencies(node.getAssignments().keySet(), node.getOutputSymbols(), "Assignments must contain mappings for output symbols");
            return null;
        }
        @Override
        public Void visitTableScan(TableScanNode nodeVoid context)
        {
            verifyUniqueId(node);
            checkArgument(node.getAssignments().keySet().containsAll(node.getOutputSymbols()), "Assignments must contain mappings for output symbols");
            return null;
        }
        @Override
        public Void visitValues(ValuesNode nodeVoid context)
        {
            verifyUniqueId(node);
            return null;
        }
        @Override
        public Void visitUnnest(UnnestNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext);
            verifyUniqueId(node);
            Set<Symbolrequired = ImmutableSet.<Symbol>builder()
                    .addAll(node.getReplicateSymbols())
                    .addAll(node.getUnnestSymbols().keySet())
                    .build();
            checkDependencies(source.getOutputSymbols(), required"Invalid node. Dependencies (%s) not in source plan output (%s)"requiredsource.getOutputSymbols());
            return null;
        }
        @Override
        public Void visitRemoteSource(RemoteSourceNode nodeVoid context)
        {
            verifyUniqueId(node);
            return null;
        }
        @Override
        public Void visitExchange(ExchangeNode nodeVoid context)
        {
            verifyUniqueId(node);
            return null;
        }
        @Override
        public Void visitTableWriter(TableWriterNode nodeVoid context)
        {
            PlanNode source = node.getSource();
            source.accept(thiscontext); // visit child
            verifyUniqueId(node);
            if (node.getSampleWeightSymbol().isPresent()) {
                checkArgument(source.getOutputSymbols().contains(node.getSampleWeightSymbol().get()), "Invalid node. Sample weight symbol (%s) is not in source plan output (%s)"node.getSampleWeightSymbol().get(), node.getSource().getOutputSymbols());
            }
            return null;
        }
        @Override
        public Void visitTableCommit(TableCommitNode nodeVoid context)
        {
            node.getSource().accept(thiscontext); // visit child
            verifyUniqueId(node);
            return null;
        }
        @Override
        public Void visitUnion(UnionNode nodeVoid context)
        {
            for (int i = 0; i < node.getSources().size(); i++) {
                PlanNode subplan = node.getSources().get(i);
                checkDependencies(subplan.getOutputSymbols(), node.sourceOutputLayout(i), "UNION subplan must provide all of the necessary symbols");
                subplan.accept(thiscontext); // visit child
            }
            verifyUniqueId(node);
            return null;
        }
        private void verifyUniqueId(PlanNode node)
        {
            PlanNodeId id = node.getId();
            checkArgument(!.containsKey(id), "Duplicate node id found %s between %s and %s"node.getId(), node.get(id));
            .put(idnode);
        }
    }
    private static void checkDependencies(Collection<SymbolinputsCollection<SymbolrequiredString messageObject... parameters)
    {
        checkArgument(ImmutableSet.copyOf(inputs).containsAll(required), messageparameters);
    }
New to GrepCode? Check out our FAQ X