Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package com.facebook.presto.sql.planner;
 
 
 import com.facebook.presto.sql.ExpressionUtils;
 import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
 import com.facebook.presto.sql.tree.ComparisonExpression;
 import com.facebook.presto.sql.tree.Expression;
 import com.facebook.presto.sql.tree.LongLiteral;
 import com.facebook.presto.sql.tree.QualifiedNameReference;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicate;
 import com.google.common.base.Predicates;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import org.testng.Assert;
 import org.testng.annotations.Test;

 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Set;

 import static com.facebook.presto.sql.tree.ComparisonExpression.Type.EQUAL;
 import static com.facebook.presto.sql.tree.ComparisonExpression.Type.GREATER_THAN;
 import static com.facebook.presto.util.ImmutableCollectors.toImmutableSet;
 import static com.google.common.base.Predicates.not;
 
 public class TestEqualityInference
 {
     @Test
     public void testTransitivity()
             throws Exception
     {
         EqualityInference.Builder builder = new EqualityInference.Builder();
         addEquality("a1""b1"builder);
         addEquality("b1""c1"builder);
         addEquality("d1""c1"builder);
 
         addEquality("a2""b2"builder);
         addEquality("b2""a2"builder);
         addEquality("b2""c2"builder);
         addEquality("d2""b2"builder);
         addEquality("c2""d2"builder);
 
         EqualityInference inference = builder.build();
 
         Assert.assertEquals(
                 inference.rewriteExpression(someExpression("a1""a2"), matchesSymbols("d1""d2")),
                 someExpression("d1""d2"));
 
         Assert.assertEquals(
                 inference.rewriteExpression(someExpression("a1""c1"), matchesSymbols("b1")),
                 someExpression("b1""b1"));
 
         Assert.assertEquals(
                 inference.rewriteExpression(someExpression("a1""a2"), matchesSymbols("b1""d2""c3")),
                 someExpression("b1""d2"));
 
         // Both starting expressions should canonicalize to the same expression
         Assert.assertEquals(
                 inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2""d2")),
                 inference.getScopedCanonical(nameReference("b2"), matchesSymbols("c2""d2")));
         Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2""d2"));
 
         // Given multiple translatable candidates, should choose the canonical
         Assert.assertEquals(
                 inference.rewriteExpression(someExpression("a2""b2"), matchesSymbols("c2""d2")),
                 someExpression(canonicalcanonical));
     }
 
     @Test
     public void testTriviallyRewritable()
             throws Exception
     {
         EqualityInference.Builder builder = new EqualityInference.Builder();
         Expression expression = builder.build()
                 .rewriteExpression(someExpression("a1""a2"), matchesSymbols("a1""a2"));
 
         Assert.assertEquals(expressionsomeExpression("a1""a2"));
     }
 
     @Test
     public void testUnrewritable()
             throws Exception
     {
         EqualityInference.Builder builder = new EqualityInference.Builder();
         addEquality("a1""b1"builder);
         addEquality("a2""b2"builder);
        EqualityInference inference = builder.build();
        Assert.assertNull(inference.rewriteExpression(someExpression("a1""a2"), matchesSymbols("b1""c1")));
        Assert.assertNull(inference.rewriteExpression(someExpression("c1""c2"), matchesSymbols("a1""a2")));
    }
    @Test
    public void testParseEqualityExpression()
            throws Exception
    {
        EqualityInference inference = new EqualityInference.Builder()
                .addEquality(equals("a1""b1"))
                .addEquality(equals("a1""c1"))
                .addEquality(equals("c1""a1"))
                .build();
        Expression expression = inference.rewriteExpression(someExpression("a1""b1"), matchesSymbols("c1"));
        Assert.assertEquals(expressionsomeExpression("c1""c1"));
    }
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression1()
            throws Exception
    {
        new EqualityInference.Builder()
                .addEquality(equals("a1""a1"));
    }
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression2()
            throws Exception
    {
        new EqualityInference.Builder()
                .addEquality(someExpression("a1""b1"));
    }
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression3()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1""a1"builder);
    }
    @Test
    public void testExtractInferrableEqualities()
            throws Exception
    {
        EqualityInference inference = new EqualityInference.Builder()
                .extractInferenceCandidates(ExpressionUtils.and(equals("a1""b1"), equals("b1""c1"), someExpression("c1""d1")))
                .build();
        // Able to rewrite to c1 due to equalities
        Assert.assertEquals(nameReference("c1"), inference.rewriteExpression(nameReference("a1"), matchesSymbols("c1")));
        // But not be able to rewrite to d1 which is not connected via equality
        Assert.assertNull(inference.rewriteExpression(nameReference("a1"), matchesSymbols("d1")));
    }
    @Test
    public void testEqualityPartitionGeneration()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        builder.addEquality(nameReference("a1"), nameReference("b1"));
        builder.addEquality(add("a1""a1"), multiply(nameReference("a1"), number(2)));
        builder.addEquality(nameReference("b1"), nameReference("c1"));
        builder.addEquality(add("a1""a1"), nameReference("c1"));
        builder.addEquality(add("a1""b1"), nameReference("c1"));
        EqualityInference inference = builder.build();
        EqualityInference.EqualityPartition emptyScopePartition = inference.generateEqualitiesPartitionedBy(Predicates.<Symbol>alwaysFalse());
        // Cannot generate any scope equalities with no matching symbols
        Assert.assertTrue(emptyScopePartition.getScopeEqualities().isEmpty());
        // All equalities should be represented in the inverse scope
        Assert.assertFalse(emptyScopePartition.getScopeComplementEqualities().isEmpty());
        // There should be no equalities straddling the scope
        Assert.assertTrue(emptyScopePartition.getScopeStraddlingEqualities().isEmpty());
        EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("c1"));
        // There should be equalities in the scope, that only use c1 and are all inferrable equalities
        Assert.assertFalse(equalityPartition.getScopeEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1"))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate()));
        // There should be equalities in the inverse scope, that never use c1 and are all inferrable equalities
        Assert.assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(matchesSymbols("c1")))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate()));
        // There should be equalities in the straddling scope, that should use both c1 and not c1 symbols
        Assert.assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty());
        Assert.assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(matchesSymbols("c1"))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate()));
        // There should be a "full cover" of all of the equalities used
        // THUS, we should be able to plug the generated equalities back in and get an equivalent set of equalities back the next time around
        EqualityInference newInference = new EqualityInference.Builder()
                .addAllEqualities(equalityPartition.getScopeEqualities())
                .addAllEqualities(equalityPartition.getScopeComplementEqualities())
                .addAllEqualities(equalityPartition.getScopeStraddlingEqualities())
                .build();
        EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(matchesSymbols("c1"));
        Assert.assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
    }
    @Test
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1""b1"builder);
        addEquality("b1""c1"builder);
        addEquality("c1""d1"builder);
        addEquality("a2""b2"builder);
        addEquality("b2""c2"builder);
        addEquality("c2""d2"builder);
        EqualityInference inference = builder.build();
        // Generating equalities for disjoint groups
        EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(symbolBeginsWith("a""b"));
        // There should be equalities in the scope, that only use a* and b* symbols and are all inferrable equalities
        Assert.assertFalse(equalityPartition.getScopeEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(symbolBeginsWith("a""b"))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate()));
        // There should be equalities in the inverse scope, that never use a* and b* symbols and are all inferrable equalities
        Assert.assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(symbolBeginsWith("a""b")))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate()));
        // There should be equalities in the straddling scope, that should use both c1 and not c1 symbols
        Assert.assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty());
        Assert.assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(symbolBeginsWith("a""b"))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate()));
        // Again, there should be a "full cover" of all of the equalities used
        // THUS, we should be able to plug the generated equalities back in and get an equivalent set of equalities back the next time around
        EqualityInference newInference = new EqualityInference.Builder()
                .addAllEqualities(equalityPartition.getScopeEqualities())
                .addAllEqualities(equalityPartition.getScopeComplementEqualities())
                .addAllEqualities(equalityPartition.getScopeStraddlingEqualities())
                .build();
        EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(symbolBeginsWith("a""b"));
        Assert.assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
    }
    @Test
    public void testSubExpressionRewrites()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        builder.addEquality(nameReference("a1"), add("b""c")); // a1 = b + c
        builder.addEquality(nameReference("a2"), multiply(nameReference("b"), add("b""c"))); // a2 = b * (b + c)
        builder.addEquality(nameReference("a3"), multiply(nameReference("a1"), add("b""c"))); // a3 = a1 * (b + c)
        EqualityInference inference = builder.build();
        // Expression (b + c) should get entirely rewritten as a1
        Assert.assertEquals(inference.rewriteExpression(add("b""c"), symbolBeginsWith("a")), nameReference("a1"));
        // Only the sub-expression (b + c) should get rewritten in terms of a*
        Assert.assertEquals(inference.rewriteExpression(multiply(nameReference("ax"), add("b""c")), symbolBeginsWith("a")), multiply(nameReference("ax"), nameReference("a1")));
        // To be compliant, could rewrite either the whole expression, or just the sub-expression. Rewriting larger expressions are preferred
        Assert.assertEquals(inference.rewriteExpression(multiply(nameReference("a1"), add("b""c")), symbolBeginsWith("a")), nameReference("a3"));
    }
    @Test
    public void testConstantEqualities()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1""b1"builder);
        addEquality("b1""c1"builder);
        builder.addEquality(nameReference("c1"), number(1));
        EqualityInference inference = builder.build();
        // Should always prefer a constant if available (constant is part of all scopes)
        Assert.assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1""b1")), number(1));
        // All scope equalities should utilize the constant if possible
        EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1""b1"));
        Assert.assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()),
                set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1))));
        Assert.assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()),
                set(set(nameReference("c1"), number(1))));
        // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope
        Assert.assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty());
    }
    private static Predicate<ExpressionmatchesSymbolScope(final Predicate<SymbolsymbolScope)
    {
        return expression -> Iterables.all(DependencyExtractor.extractUnique(expression), symbolScope);
    }
    private static Predicate<ExpressionmatchesStraddlingScope(final Predicate<SymbolsymbolScope)
    {
        return expression -> {
            Set<Symbolsymbols = DependencyExtractor.extractUnique(expression);
            return Iterables.any(symbolssymbolScope) && Iterables.any(symbolsnot(symbolScope));
        };
    }
    private static void addEquality(String symbol1String symbol2EqualityInference.Builder builder)
    {
        builder.addEquality(nameReference(symbol1), nameReference(symbol2));
    }
    private static Expression someExpression(String symbol1String symbol2)
    {
        return someExpression(nameReference(symbol1), nameReference(symbol2));
    }
    private static Expression someExpression(Expression expression1Expression expression2)
    {
        return new ComparisonExpression(expression1expression2);
    }
    private static Expression add(String symbol1String symbol2)
    {
        return add(nameReference(symbol1), nameReference(symbol2));
    }
    private static Expression add(Expression expression1Expression expression2)
    {
        return new ArithmeticBinaryExpression(..expression1expression2);
    }
    private static Expression multiply(String symbol1String symbol2)
    {
        return multiply(nameReference(symbol1), nameReference(symbol2));
    }
    private static Expression multiply(Expression expression1Expression expression2)
    {
        return new ArithmeticBinaryExpression(..expression1expression2);
    }
    private static Expression equals(String symbol1String symbol2)
    {
        return equals(nameReference(symbol1), nameReference(symbol2));
    }
    private static Expression equals(Expression expression1Expression expression2)
    {
        return new ComparisonExpression(expression1expression2);
    }
    private static QualifiedNameReference nameReference(String symbol)
    {
        return new QualifiedNameReference(new Symbol(symbol).toQualifiedName());
    }
    private static LongLiteral number(long number)
    {
        return new LongLiteral(String.valueOf(number));
    }
    private static Predicate<SymbolmatchesSymbols(String... symbols)
    {
        return matchesSymbols(Arrays.asList(symbols));
    }
    private static Predicate<SymbolmatchesSymbols(Collection<Stringsymbols)
    {
        final Set<SymbolsymbolSet = symbols.stream()
                .map(Symbol::new)
                .collect(toImmutableSet());
        return Predicates.in(symbolSet);
    }
    private static Predicate<SymbolsymbolBeginsWith(String... prefixes)
    {
        return symbolBeginsWith(Arrays.asList(prefixes));
    }
    private static Predicate<SymbolsymbolBeginsWith(final Iterable<Stringprefixes)
    {
        return symbol -> {
            for (String prefix : prefixes) {
                if (symbol.getName().startsWith(prefix)) {
                    return true;
                }
            }
            return false;
        };
    }
    private static Set<Set<Expression>> equalitiesAsSets(Iterable<Expressionexpressions)
    {
        ImmutableSet.Builder<Set<Expression>> builder = ImmutableSet.builder();
        for (Expression expression : expressions) {
            builder.add(equalityAsSet(expression));
        }
        return builder.build();
    }
    private static Set<ExpressionequalityAsSet(Expression expression)
    {
        Preconditions.checkArgument(expression instanceof ComparisonExpression);
        ComparisonExpression comparisonExpression = (ComparisonExpressionexpression;
        Preconditions.checkArgument(comparisonExpression.getType() == );
        return ImmutableSet.of(comparisonExpression.getLeft(), comparisonExpression.getRight());
    }
    private static <E> Set<E> set(E... elements)
    {
        return setCopy(Arrays.asList(elements));
    }
    private static <E> Set<E> setCopy(Iterable<E> elements)
    {
        return ImmutableSet.copyOf(elements);
    }
New to GrepCode? Check out our FAQ X