Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.util;
 
 
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;

 import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
 import static com.google.common.collect.Maps.immutableEnumMap;
 import static java.lang.String.format;
 import static java.util.Objects.requireNonNull;
 
 public final class GraphvizPrinter
 {
     private enum NodeType
     {
         EXCHANGE,
         AGGREGATE,
         FILTER,
         PROJECT,
         TOPN,
         OUTPUT,
         LIMIT,
         TABLESCAN,
         VALUES,
         JOIN,
         SINK,
         WINDOW,
         UNION,
         SORT,
         SAMPLE,
         MARK_DISTINCT,
         TABLE_WRITER,
         TABLE_COMMIT,
         INDEX_SOURCE,
         UNNEST
     }
 
     // Graphviz fill color for each NodeType, built as an immutable enum-keyed map via
     // immutableEnumMap. The static initializer that follows checks that the map has exactly one
     // entry per NodeType value.
     //
     // NOTE(review): in this copy the NodeType key of every .put(...) call (e.g. "NodeType.X, ")
     // and the ", " / "> " separators inside Map<...> and ImmutableMap.<...> appear to have been
     // stripped, so these lines do not compile as shown. Which key belongs to which color cannot be
     // recovered from here -- restore this initializer from the upstream source before use.
     private static final Map<NodeTypeStringNODE_COLORS = immutableEnumMap(ImmutableMap.<NodeTypeString>builder()
             .put(."gold")
             .put(."chartreuse3")
             .put(."yellow")
             .put(."bisque")
             .put(."darksalmon")
             .put(."white")
            .put(."gray83")
            .put(."deepskyblue")
            .put(."deepskyblue")
            .put(."orange")
            .put(."aliceblue")
            .put(."indianred1")
            .put(."darkolivegreen4")
            .put(."turquoise4")
            .put(."violet")
            .put(."cyan")
            .put(."hotpink")
            .put(."dodgerblue3")
            .put(."crimson")
            .put(."goldenrod4")
            .build());
    static {
        Preconditions.checkState(.size() == NodeType.values().length);
    }
    private GraphvizPrinter() {}
    public static String printLogical(List<PlanFragmentfragments)
    {
        Map<PlanFragmentIdPlanFragmentfragmentsById = Maps.uniqueIndex(fragments, PlanFragment::getId);
        PlanNodeIdGenerator idGenerator = new PlanNodeIdGenerator();
        StringBuilder output = new StringBuilder();
        output.append("digraph logical_plan {\n");
        for (PlanFragment fragment : fragments) {
            printFragmentNodes(outputfragmentidGenerator);
        }
        for (PlanFragment fragment : fragments) {
            fragment.getRoot().accept(new EdgePrinter(outputfragmentsByIdidGenerator), null);
        }
        output.append("}\n");
        return output.toString();
    }
    public static String printDistributed(SubPlan plan)
    {
        List<PlanFragmentfragments = plan.getAllFragments();
        Map<PlanFragmentIdPlanFragmentfragmentsById = Maps.uniqueIndex(fragments, PlanFragment::getId);
        PlanNodeIdGenerator idGenerator = new PlanNodeIdGenerator();
        StringBuilder output = new StringBuilder();
        output.append("digraph distributed_plan {\n");
        printSubPlan(planfragmentsByIdidGeneratoroutput);
        output.append("}\n");
        return output.toString();
    }
    private static void printSubPlan(SubPlan planMap<PlanFragmentIdPlanFragmentfragmentsByIdPlanNodeIdGenerator idGeneratorStringBuilder output)
    {
        PlanFragment fragment = plan.getFragment();
        printFragmentNodes(outputfragmentidGenerator);
        fragment.getRoot().accept(new EdgePrinter(outputfragmentsByIdidGenerator), null);
        for (SubPlan child : plan.getChildren()) {
            printSubPlan(childfragmentsByIdidGeneratoroutput);
        }
    }
    private static void printFragmentNodes(StringBuilder outputPlanFragment fragmentPlanNodeIdGenerator idGenerator)
    {
        String clusterId = "cluster_" + fragment.getId();
        output.append("subgraph ")
                .append(clusterId)
                .append(" {")
                .append('\n');
        output.append(format("label = \"%s\""fragment.getDistribution()))
                .append('\n');
        PlanNode plan = fragment.getRoot();
        plan.accept(new NodePrinter(outputidGenerator), null);
        output.append("}")
                .append('\n');
    }
    private static class NodePrinter
            extends PlanVisitor<VoidVoid>
    {
        private static final int MAX_NAME_WIDTH = 100;
        private final StringBuilder output;
        private final PlanNodeIdGenerator idGenerator;
        public NodePrinter(StringBuilder outputPlanNodeIdGenerator idGenerator)
        {
            this. = output;
            this. = idGenerator;
        }
        @Override
        protected Void visitPlan(PlanNode nodeVoid context)
        {
            throw new UnsupportedOperationException(format("Node %s does not have a Graphviz visitor"node.getClass().getName()));
        }
        @Override
        public Void visitTableWriter(TableWriterNode nodeVoid context)
        {
            List<Stringcolumns = new ArrayList<>();
            for (int i = 0; i < node.getColumnNames().size(); i++) {
                columns.add(node.getColumnNames().get(i) + " := " + node.getColumns().get(i));
            }
            printNode(nodeformat("TableWriter[%s]", Joiner.on(", ").join(columns)), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitTableCommit(TableCommitNode nodeVoid context)
        {
            printNode(nodeformat("TableCommit[%s]", Joiner.on(", ").join(node.getOutputSymbols())), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitSample(SampleNode nodeVoid context)
        {
            printNode(nodeformat("Sample[type=%s, ratio=%f, rescaled=%s]"node.getSampleType(), node.getSampleRatio(), node.isRescaled()), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitSort(SortNode nodeVoid context)
        {
            printNode(nodeformat("Sort[%s]", Joiner.on(", ").join(node.getOrderBy())), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitMarkDistinct(MarkDistinctNode nodeVoid context)
        {
            printNode(nodeformat("MarkDistinct[%s]"node.getMarkerSymbol()), format("%s => %s"node.getDistinctSymbols(), node.getMarkerSymbol()), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitWindow(WindowNode nodeVoid context)
        {
            printNode(node"Window"format("partition by = %s|order by = %s", Joiner.on(", ").join(node.getPartitionBy()), Joiner.on(", ").join(node.getOrderBy())), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitRowNumber(RowNumberNode nodeVoid context)
        {
            printNode(node,
                    "RowNumber",
                    format("partition by = %s", Joiner.on(", ").join(node.getPartitionBy())),
                    .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitTopNRowNumber(TopNRowNumberNode nodeVoid context)
        {
            printNode(node,
                    "TopNRowNumber",
                    format("partition by = %s|order by = %s|n = %s", Joiner.on(", ").join(node.getPartitionBy()), Joiner.on(", ").join(node.getOrderBy()), node.getMaxRowCountPerPartition()),
                    .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitUnion(UnionNode nodeVoid context)
        {
            printNode(node"Union".get(.));
            for (PlanNode planNode : node.getSources()) {
                planNode.accept(thiscontext);
            }
            return null;
        }
        @Override
        public Void visitRemoteSource(RemoteSourceNode nodeVoid context)
        {
            printNode(node"Exchange 1:N".get(.));
            return null;
        }
        @Override
        public Void visitExchange(ExchangeNode nodeVoid context)
        {
            List<Symbolsymbols = node.getOutputSymbols();
            if (node.getType() == ) {
                symbols = node.getPartitionKeys();
            }
            String columns = Joiner.on(", ").join(symbols);
            printNode(nodeformat("ExchangeNode[%s]"node.getType()), columns.get(.));
            for (PlanNode planNode : node.getSources()) {
                planNode.accept(thiscontext);
            }
            return null;
        }
        @Override
        public Void visitAggregation(AggregationNode nodeVoid context)
        {
            StringBuilder builder = new StringBuilder();
            for (Map.Entry<SymbolFunctionCallentry : node.getAggregations().entrySet()) {
                if (node.getMasks().containsKey(entry.getKey())) {
                    builder.append(format("%s := %s (mask = %s)\\n"entry.getKey(), entry.getValue(), node.getMasks().get(entry.getKey())));
                }
                else {
                    builder.append(format("%s := %s\\n"entry.getKey(), entry.getValue()));
                }
            }
            printNode(nodeformat("Aggregate[%s]"node.getStep()), builder.toString(), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitFilter(FilterNode nodeVoid context)
        {
            String expression = node.getPredicate().toString();
            printNode(node"Filter"expression.get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitProject(ProjectNode nodeVoid context)
        {
            StringBuilder builder = new StringBuilder();
            for (Map.Entry<SymbolExpressionentry : node.getAssignments().entrySet()) {
                if ((entry.getValue() instanceof QualifiedNameReference) &&
                        ((QualifiedNameReferenceentry.getValue()).getName().equals(entry.getKey().toQualifiedName())) {
                    // skip identity assignments
                    continue;
                }
                builder.append(format("%s := %s\\n"entry.getKey(), entry.getValue()));
            }
            printNode(node"Project"builder.toString(), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitUnnest(UnnestNode nodeVoid context)
        {
            if (node.getOrdinalitySymbol() == null) {
                printNode(nodeformat("Unnest[%s]"node.getUnnestSymbols().keySet()), .get(.));
            }
            else {
                printNode(nodeformat("Unnest[%s (ordinality)]"node.getUnnestSymbols().keySet()), .get(.));
            }
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitTopN(final TopNNode nodeVoid context)
        {
            Iterable<Stringkeys = Iterables.transform(node.getOrderBy(), input -> input + " " + node.getOrderings().get(input));
            printNode(nodeformat("TopN[%s]"node.getCount()), Joiner.on(", ").join(keys), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitOutput(OutputNode nodeVoid context)
        {
            String columns = getColumns(node);
            printNode(nodeformat("Output[%s]"columns), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitDistinctLimit(DistinctLimitNode nodeVoid context)
        {
            printNode(nodeformat("DistinctLimit[%s]"node.getLimit()), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitLimit(LimitNode nodeVoid context)
        {
            printNode(nodeformat("Limit[%s]"node.getCount()), .get(.));
            return node.getSource().accept(thiscontext);
        }
        @Override
        public Void visitTableScan(TableScanNode nodeVoid context)
        {
            printNode(nodeformat("TableScan[%s]"node.getTable()), format("original constraint=%s"node.getOriginalConstraint()), .get(.));
            return null;
        }
        @Override
        public Void visitValues(ValuesNode nodeVoid context)
        {
            printNode(node"Values".get(.));
            return null;
        }
        @Override
        public Void visitJoin(JoinNode nodeVoid context)
        {
            List<ExpressionjoinExpressions = new ArrayList<>();
            for (JoinNode.EquiJoinClause clause : node.getCriteria()) {
                joinExpressions.add(new ComparisonExpression(..,
                        new QualifiedNameReference(clause.getLeft().toQualifiedName()),
                        new QualifiedNameReference(clause.getRight().toQualifiedName())));
            }
            String criteria = Joiner.on(" AND ").join(joinExpressions);
            printNode(nodenode.getType().getJoinLabel(), criteria.get(.));
            node.getLeft().accept(thiscontext);
            node.getRight().accept(thiscontext);
            return null;
        }
        @Override
        public Void visitSemiJoin(SemiJoinNode nodeVoid context)
        {
            printNode(node"SemiJoin"format("%s = %s"node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol()), .get(.));
            node.getSource().accept(thiscontext);
            node.getFilteringSource().accept(thiscontext);
            return null;
        }
        @Override
        public Void visitIndexSource(IndexSourceNode nodeVoid context)
        {
            printNode(nodeformat("IndexSource[%s]"node.getIndexHandle()), .get(.));
            return null;
        }
        @Override
        public Void visitIndexJoin(IndexJoinNode nodeVoid context)
        {
            List<ExpressionjoinExpressions = new ArrayList<>();
            for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) {
                joinExpressions.add(new ComparisonExpression(..,
                        new QualifiedNameReference(clause.getProbe().toQualifiedName()),
                        new QualifiedNameReference(clause.getIndex().toQualifiedName())));
            }
            String criteria = Joiner.on(" AND ").join(joinExpressions);
            String joinLabel = format("%sIndexJoin"node.getType().getJoinLabel());
            printNode(nodejoinLabelcriteria.get(.));
            node.getProbeSource().accept(thiscontext);
            node.getIndexSource().accept(thiscontext);
            return null;
        }
        private void printNode(PlanNode nodeString labelString color)
        {
            String nodeId = .getNodeId(node);
            label = escapeSpecialCharacters(label);
            .append(nodeId)
                    .append(format("[label=\"{%s}\", style=\"rounded, filled\", shape=record, fillcolor=%s]"labelcolor))
                    .append(';')
                    .append('\n');
        }
        private void printNode(PlanNode nodeString labelString detailsString color)
        {
            if (details.isEmpty()) {
                printNode(nodelabelcolor);
            }
            else {
                String nodeId = .getNodeId(node);
                label = escapeSpecialCharacters(label);
                details = escapeSpecialCharacters(details);
                .append(nodeId)
                        .append(format("[label=\"{%s|%s}\", style=\"rounded, filled\", shape=record, fillcolor=%s]"labeldetailscolor))
                        .append(';')
                        .append('\n');
            }
        }
        private static String getColumns(OutputNode node)
        {
            Iterator<StringcolumnNames = node.getColumnNames().iterator();
            String columns = "";
            int nameWidth = 0;
            while (columnNames.hasNext()) {
                String columnName = columnNames.next();
                columns += columnName;
                nameWidth += columnName.length();
                if (columnNames.hasNext()) {
                    columns += ", ";
                }
                if (nameWidth >= ) {
                    columns += "\\n";
                    nameWidth = 0;
                }
            }
            return columns;
        }

        
Escape characters that are special to graphviz.
        private static String escapeSpecialCharacters(String label)
        {
            return label
                    .replace("<""\\<")
                    .replace(">""\\>")
                    .replace("\"""\\\"");
        }
    }
    private static class EdgePrinter
            extends PlanVisitor<VoidVoid>
    {
        private final StringBuilder output;
        private final Map<PlanFragmentIdPlanFragmentfragmentsById;
        private final PlanNodeIdGenerator idGenerator;
        public EdgePrinter(StringBuilder outputMap<PlanFragmentIdPlanFragmentfragmentsByIdPlanNodeIdGenerator idGenerator)
        {
            this. = output;
            this. = ImmutableMap.copyOf(fragmentsById);
            this. = idGenerator;
        }
        @Override
        protected Void visitPlan(PlanNode nodeVoid context)
        {
            for (PlanNode child : node.getSources()) {
                printEdge(nodechild);
                child.accept(thiscontext);
            }
            return null;
        }
        @Override
        public Void visitRemoteSource(RemoteSourceNode nodeVoid context)
        {
            for (PlanFragmentId planFragmentId : node.getSourceFragmentIds()) {
                PlanFragment target = .get(planFragmentId);
                printEdge(nodetarget.getRoot());
            }
            return null;
        }
        private void printEdge(PlanNode fromPlanNode to)
        {
            String fromId = .getNodeId(from);
            String toId = .getNodeId(to);
            .append(fromId)
                    .append(" -> ")
                    .append(toId)
                    .append(';')
                    .append('\n');
        }
    }
    private static class PlanNodeIdGenerator
    {
        private final Map<PlanNodeIntegerplanNodeIds;
        private int idCount;
        public PlanNodeIdGenerator()
        {
             = new HashMap<>();
        }
        public String getNodeId(PlanNode from)
        {
            int nodeId;
            if (.containsKey(from)) {
                nodeId = .get(from);
            }
            else {
                ++;
                .put(from);
                nodeId = ;
            }
            return ("plannode_" + nodeId);
        }
    }
New to GrepCode? Check out our FAQ X