Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
 package com.facebook.presto.sql.planner.optimizations;
 
 
 import java.util.Map;
 
 import static com.google.common.base.Preconditions.checkNotNull;
 
 public class CountConstantOptimizer
         extends PlanOptimizer
 {
     @Override
     public PlanNode optimize(PlanNode planSession sessionMap<SymbolTypetypesSymbolAllocator symbolAllocatorPlanNodeIdAllocator idAllocator)
     {
         checkNotNull(plan"plan is null");
         checkNotNull(session"session is null");
         checkNotNull(types"types is null");
         checkNotNull(symbolAllocator"symbolAllocator is null");
         checkNotNull(idAllocator"idAllocator is null");
 
         return PlanRewriter.rewriteWith(new Rewriter(), plan);
     }
 
     private static class Rewriter
             extends PlanRewriter<Void>
     {
         @Override
         public PlanNode visitAggregation(AggregationNode nodeRewriteContext<Voidcontext)
         {
             Map<SymbolFunctionCallaggregations = new LinkedHashMap<>(node.getAggregations());
             Map<SymbolSignaturefunctions = new LinkedHashMap<>(node.getFunctions());
 
             PlanNode source = context.rewrite(node.getSource());
             if (source instanceof ProjectNode) {
                 ProjectNode projectNode = (ProjectNodesource;
                 for (Entry<SymbolFunctionCallentry : node.getAggregations().entrySet()) {
                     Symbol symbol = entry.getKey();
                     FunctionCall functionCall = entry.getValue();
                     Signature signature = node.getFunctions().get(symbol);
                     if (isCountConstant(projectNodefunctionCallsignature)) {
                         aggregations.put(symbolnew FunctionCall(functionCall.getName(), nullfunctionCall.isDistinct(), ImmutableList.<Expression>of()));
                         functions.put(symbolnew Signature("count".));
                     }
                 }
             }
 
             return new AggregationNode(
                     node.getId(),
                     source,
                     node.getGroupBy(),
                     aggregations,
                     functions,
                     node.getMasks(),
                     node.getStep(),
                     node.getSampleWeight(),
                     node.getConfidence(),
                     node.getHashSymbol());
         }
 
         public static boolean isCountConstant(ProjectNode projectNodeFunctionCall functionCallSignature signature)
         {
             if (!"count".equals(signature.getName()) ||
                     signature.getArgumentTypes().size() != 1 ||
                     !signature.getReturnType().equals(.)) {
                 return false;
             }
 
            Expression argument = functionCall.getArguments().get(0);
            if (argument instanceof Literal) {
                return true;
            }
            if (argument instanceof QualifiedNameReference) {
                QualifiedNameReference qualifiedNameReference = (QualifiedNameReferenceargument;
                QualifiedName qualifiedName = qualifiedNameReference.getName();
                Symbol argumentSymbol = Symbol.fromQualifiedName(qualifiedName);
                Expression argumentExpression = projectNode.getAssignments().get(argumentSymbol);
                return (argumentExpression instanceof Literal) && (!(argumentExpression instanceof NullLiteral));
            }
            return false;
        }
    }
New to GrepCode? Check out our FAQ X