Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.FunctionInvocation;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.ast.Var;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
Expand Down Expand Up @@ -69,6 +70,12 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map<String,
return new ValDerivationNode(var, null);
}

if (exp instanceof FunctionInvocation) {
Expression value = subs.get(exp.toString());
if (value != null)
return new ValDerivationNode(value.clone(), new VarDerivationNode(exp.toString()));
}

// lift unary origin
if (exp instanceof UnaryExpression unary) {
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs, varOrigins);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.FunctionInvocation;
import liquidjava.rj_language.ast.Var;

public class VariableResolver {
Expand Down Expand Up @@ -45,33 +46,50 @@ private static void resolveRecursive(Expression exp, Map<String, Expression> map
if ("&&".equals(op)) {
resolveRecursive(be.getFirstOperand(), map);
resolveRecursive(be.getSecondOperand(), map);
} else if ("==".equals(op)) {
Expression left = be.getFirstOperand();
Expression right = be.getSecondOperand();
if (left instanceof Var var && right.isLiteral()) {
map.put(var.getName(), right.clone());
} else if (right instanceof Var var && left.isLiteral()) {
map.put(var.getName(), left.clone());
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
// to substitute internal variable with user-facing variable
if (isInternal(leftVar) && !isInternal(rightVar) && !isReturnVar(leftVar)) {
map.put(leftVar.getName(), right.clone());
} else if (isInternal(rightVar) && !isInternal(leftVar) && !isReturnVar(rightVar)) {
map.put(rightVar.getName(), left.clone());
} else if (isInternal(leftVar) && isInternal(rightVar)) {
// to substitute the lower-counter variable with the higher-counter one
boolean isLeftCounterLower = getCounter(leftVar) <= getCounter(rightVar);
Var lowerVar = isLeftCounterLower ? leftVar : rightVar;
Var higherVar = isLeftCounterLower ? rightVar : leftVar;
if (!isReturnVar(lowerVar) && !isFreshVar(higherVar))
map.putIfAbsent(lowerVar.getName(), higherVar.clone());
}
} else if (left instanceof Var var && !(right instanceof Var) && canSubstitute(var, right)) {
map.put(var.getName(), right.clone());
return;
}
if (!"==".equals(op))
return;

Expression left = be.getFirstOperand();
Expression right = be.getSecondOperand();
String leftKey = substitutionKey(left);
String rightKey = substitutionKey(right);

if (leftKey != null && right.isLiteral()) {
map.put(leftKey, right.clone());
} else if (rightKey != null && left.isLiteral()) {
map.put(rightKey, left.clone());
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
// to substitute internal variable with user-facing variable
if (isInternal(leftVar) && !isInternal(rightVar) && !isReturnVar(leftVar)) {
map.put(leftVar.getName(), right.clone());
} else if (isInternal(rightVar) && !isInternal(leftVar) && !isReturnVar(rightVar)) {
map.put(rightVar.getName(), left.clone());
} else if (isInternal(leftVar) && isInternal(rightVar)) {
// to substitute the lower-counter variable with the higher-counter one
boolean isLeftCounterLower = getCounter(leftVar) <= getCounter(rightVar);
Var lowerVar = isLeftCounterLower ? leftVar : rightVar;
Var higherVar = isLeftCounterLower ? rightVar : leftVar;
if (!isReturnVar(lowerVar) && !isFreshVar(higherVar))
map.putIfAbsent(lowerVar.getName(), higherVar.clone());
}
} else if (left instanceof Var var && !(right instanceof Var) && canSubstitute(var, right)) {
map.put(var.getName(), right.clone());
} else if (left instanceof FunctionInvocation && !(right instanceof Var)
&& !right.toString().contains(leftKey)) {
map.put(leftKey, right.clone());
}
}

private static String substitutionKey(Expression exp) {
if (exp instanceof Var var)
return var.getName();
if (exp instanceof FunctionInvocation)
return exp.toString();
return null;
}

/**
* Handles transitive variable equalities in the map (e.g. map: x -> y, y -> 1 => map: x -> 1, y -> 1)
*
Expand Down Expand Up @@ -129,14 +147,22 @@ private static boolean hasUsage(Expression exp, String name) {
if (left instanceof Var v && v.getName().equals(name)
&& (right.isLiteral() || (!(right instanceof Var) && canSubstitute(v, right))))
return false;
if (left instanceof FunctionInvocation && left.toString().equals(name)
&& (right.isLiteral() || (!(right instanceof Var) && !right.toString().contains(name))))
return false;
if (right instanceof Var v && v.getName().equals(name) && left.isLiteral())
return false;
if (right instanceof FunctionInvocation && right.toString().equals(name) && left.isLiteral())
return false;
}

// usage found
if (exp instanceof Var var && var.getName().equals(name)) {
return true;
}
if (exp instanceof FunctionInvocation && exp.toString().equals(name)) {
return true;
}

// recurse children
if (exp.hasChildren()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import liquidjava.rj_language.ast.AliasInvocation;
import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.FunctionInvocation;
import liquidjava.rj_language.ast.Ite;
import liquidjava.rj_language.ast.LiteralBoolean;
import liquidjava.rj_language.ast.LiteralInt;
Expand Down Expand Up @@ -1133,4 +1134,28 @@ void testFoldsAdjacentIntegerConstantsInLeftAssociatedArithmetic() {
assertEquals("x + 3", ExpressionSimplifier.simplify(xPlus1Plus2).getValue().toString());
assertEquals("x", ExpressionSimplifier.simplify(xPlus1Minus1).getValue().toString());
}

@Test
void testFunctionInvocationEqualitiesPropagateTransitively() {
// Given: size(x3) == size(x2) - 1 && size(x2) == size(x1) + 1 && size(x1) == 0
// Expected: size(x3) == 0
Expression x1 = new Var("x1");
Expression x2 = new Var("x2");
Expression x3 = new Var("x3");
Expression sizeX1 = new FunctionInvocation("size", List.of(x1));
Expression sizeX2 = new FunctionInvocation("size", List.of(x2));
Expression sizeX3 = new FunctionInvocation("size", List.of(x3));

Expression sizeX3EqualsSizeX2Minus1 = new BinaryExpression(sizeX3, "==",
new BinaryExpression(sizeX2, "-", new LiteralInt(1)));
Expression sizeX2EqualsSizeX1Plus1 = new BinaryExpression(sizeX2, "==",
new BinaryExpression(sizeX1, "+", new LiteralInt(1)));
Expression sizeX1Equals0 = new BinaryExpression(sizeX1, "==", new LiteralInt(0));
Expression fullExpression = new BinaryExpression(sizeX3EqualsSizeX2Minus1, "&&",
new BinaryExpression(sizeX2EqualsSizeX1Plus1, "&&", sizeX1Equals0));

ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);

assertEquals("size(x3) == 0", result.getValue().toString());
}
}
Loading