Arithmetic preprocessing

This commit is contained in:
Tibo De Peuter 2025-04-28 12:20:03 +02:00
parent 174855d7a3
commit 32165a90f5
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
3 changed files with 442 additions and 124 deletions

View file

@ -1,24 +1,13 @@
package interpreter
import io.Logger
import prolog.ast.arithmetic.Expression
import prolog.ast.logic.Clause
import prolog.ast.logic.Fact
import prolog.ast.logic.LogicOperand
import prolog.ast.logic.Rule
import prolog.ast.terms.Atom
import prolog.ast.terms.Body
import prolog.ast.terms.Goal
import prolog.ast.terms.Head
import prolog.ast.terms.Structure
import prolog.ast.terms.Term
import prolog.builtins.Conjunction
import prolog.builtins.Cut
import prolog.builtins.Disjunction
import prolog.builtins.Fail
import prolog.builtins.False
import prolog.builtins.Not
import prolog.builtins.Query
import prolog.builtins.True
import prolog.ast.terms.*
import prolog.builtins.*
/**
* Preprocessor for Prolog
@ -45,12 +34,14 @@ open class Preprocessor {
is Fact -> {
Fact(preprocess(clause.head) as Head)
}
is Rule -> {
Rule(
preprocess(clause.head) as Head,
preprocess(clause.body as Term) as Body
)
}
else -> clause
}
}
@ -65,23 +56,70 @@ open class Preprocessor {
Structure(Atom("fail"), emptyList()) -> Fail
Atom("!") -> Cut()
Structure(Atom("!"), emptyList()) -> Cut()
else -> {
is Structure -> {
// Preprocess the arguments first to recognize builtins
val args = term.arguments.map { preprocess(it) }
when {
term is Structure && term.functor == ",/2" -> {
val args = term.arguments.map { preprocess(it) }
// TODO Remove hardcoding by storing the functors as constants in operators?
// Logic
term.functor == ",/2" -> {
Conjunction(args[0] as LogicOperand, args[1] as LogicOperand)
}
term is Structure && term.functor == ";/2" -> {
val args = term.arguments.map { preprocess(it) }
term.functor == ";/2" -> {
Disjunction(args[0] as LogicOperand, args[1] as LogicOperand)
}
term is Structure && term.functor == "\\+/1" -> {
val args = term.arguments.map { preprocess(it) }
term.functor == "\\+/1" -> {
Not(args[0] as Goal)
}
// Arithmetic
term.functor == "=\\=/2" && args.all { it is Expression } -> {
EvaluatesToDifferent(args[0] as Expression, args[1] as Expression)
}
term.functor == "=:=/2" && args.all { it is Expression } -> {
EvaluatesTo(args[0] as Expression, args[1] as Expression)
}
term.functor == "is/2" && args.all { it is Expression } -> {
Is(args[0] as Expression, args[1] as Expression)
}
term.functor == "-/1" && args.all { it is Expression } -> {
Negate(args[0] as Expression)
}
term.functor == "-/2" && args.all { it is Expression } -> {
Subtract(args[0] as Expression, args[1] as Expression)
}
term.functor == "+/1" && args.all { it is Expression } -> {
Positive(args[0] as Expression)
}
term.functor == "+/2" && args.all { it is Expression } -> {
Add(args[0] as Expression, args[1] as Expression)
}
term.functor == "*/2" && args.all { it is Expression } -> {
Multiply(args[0] as Expression, args[1] as Expression)
}
term.functor == "//2" && args.all { it is Expression } -> {
Divide(args[0] as Expression, args[1] as Expression)
}
term.functor == "between/3" && args.all { it is Expression } -> {
Between(args[0] as Expression, args[1] as Expression, args[2] as Expression)
}
else -> term
}
}
else -> term
}
if (prepped != term || prepped::class != term::class) {

View file

@ -62,11 +62,11 @@ class EvaluatesTo(private val left: Expression, private val right: Expression) :
/**
* True when Number is the value to which Expr evaluates.
*/
class Is(private val left: Expression, private val right: Expression) :
Operator(Atom("is"), left, right), Satisfiable {
class Is(val number: Expression, val expr: Expression) :
Operator(Atom("is"), number, expr), Satisfiable {
override fun satisfy(subs: Substitutions): Answers {
val t1 = left.simplify(subs)
val t2 = right.simplify(subs)
val t1 = number.simplify(subs)
val t2 = expr.simplify(subs)
if (!atomic(t2.to, subs)) {
return sequenceOf(Result.failure(IllegalArgumentException("Right operand must be instantiated")))
@ -119,7 +119,7 @@ open class Subtract(private val expr1: Expression, private val expr2: Expression
/**
* Result = Expr1 * Expr2
*/
class Multiply(private val expr1: Expression, private val expr2: Expression) :
class Multiply(val expr1: Expression, val expr2: Expression) :
ArithmeticOperator(Atom("*"), expr1, expr2) {
override fun simplify(subs: Substitutions): Simplification {
val result = Variable("Result")

View file

@ -3,15 +3,12 @@ package interpreter
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import prolog.ast.arithmetic.Integer
import prolog.ast.terms.Atom
import prolog.ast.terms.CompoundTerm
import prolog.ast.terms.Term
import prolog.builtins.Conjunction
import prolog.builtins.Disjunction
import prolog.builtins.Cut
import prolog.builtins.Fail
import prolog.builtins.True
import prolog.builtins.Not
import prolog.ast.terms.Variable
import prolog.builtins.*
class PreprocessorTests {
class OpenPreprocessor : Preprocessor() {
@ -20,11 +17,318 @@ class PreprocessorTests {
}
}
companion object {
fun test(tests: Map<Term, Term>) {
for ((input, expected) in tests) {
val result = OpenPreprocessor().preprocess(input)
assertEquals(expected, result, "Expected preprocessed")
assertEquals(expected::class, result::class, "Expected same class")
}
}
}
@Nested
class `Arithmetic operators` {
@Test
fun `evaluates to different`() {
assertEquals(1, 2)
test(
mapOf(
Atom("=\\=") to Atom("=\\="),
CompoundTerm(Atom("=\\="), emptyList()) to CompoundTerm(Atom("=\\="), emptyList()),
Atom("EvaluatesToDifferent") to Atom("EvaluatesToDifferent"),
CompoundTerm(Atom("EvaluatesToDifferent"), emptyList()) to CompoundTerm(
Atom("EvaluatesToDifferent"),
emptyList()
),
CompoundTerm(Atom("=\\="), listOf(Atom("a"))) to CompoundTerm(
Atom("=\\="),
listOf(Atom("a"))
),
CompoundTerm(Atom("=\\="), listOf(Integer(1))) to CompoundTerm(
Atom("=\\="),
listOf(Integer(1))
),
CompoundTerm(Atom("=\\="), listOf(Atom("=\\="))) to CompoundTerm(
Atom("=\\="),
listOf(Atom("=\\="))
),
CompoundTerm(Atom("=\\="), listOf(Integer(1), Integer(2))) to EvaluatesToDifferent(
Integer(1), Integer(2)
)
)
)
}
@Test
fun `evaluates to`() {
test(
mapOf(
Atom("=:=") to Atom("=:="),
CompoundTerm(Atom("=:="), emptyList()) to CompoundTerm(Atom("=:="), emptyList()),
Atom("EvaluatesTo") to Atom("EvaluatesTo"),
CompoundTerm(Atom("EvaluatesTo"), emptyList()) to CompoundTerm(
Atom("EvaluatesTo"),
emptyList()
),
CompoundTerm(Atom("=:="), listOf(Atom("a"))) to CompoundTerm(
Atom("=:="),
listOf(Atom("a"))
),
CompoundTerm(Atom("=:="), listOf(Atom("=:="))) to CompoundTerm(
Atom("=:="),
listOf(Atom("=:="))
),
CompoundTerm(Atom("=:="), listOf(Integer(1), Integer(2))) to EvaluatesTo(
Integer(1), Integer(2)
)
)
)
}
@Test
fun `is`() {
test(
mapOf(
Atom("is") to Atom("is"),
CompoundTerm(Atom("is"), emptyList()) to CompoundTerm(Atom("is"), emptyList()),
Atom("Is") to Atom("Is"),
CompoundTerm(Atom("Is"), emptyList()) to CompoundTerm(Atom("Is"), emptyList()),
CompoundTerm(Atom("is"), listOf(Atom("a"))) to CompoundTerm(
Atom("is"),
listOf(Atom("a"))
),
CompoundTerm(Atom("is"), listOf(Integer(1))) to CompoundTerm(
Atom("is"),
listOf(Integer(1))
),
CompoundTerm(Atom("is"), listOf(Atom("is"))) to CompoundTerm(
Atom("is"),
listOf(Atom("is"))
),
CompoundTerm(Atom("is"), listOf(Integer(1), Integer(2))) to Is(
Integer(1), Integer(2)
)
)
)
}
@Test
fun `negate and subtract`() {
test(
mapOf(
Atom("-") to Atom("-"),
CompoundTerm(Atom("-"), emptyList()) to CompoundTerm(Atom("-"), emptyList()),
Atom("Negate") to Atom("Negate"),
CompoundTerm(Atom("Negate"), emptyList()) to CompoundTerm(
Atom("Negate"),
emptyList()
),
CompoundTerm(Atom("-"), listOf(Atom("a"))) to CompoundTerm(
Atom("-"),
listOf(Atom("a"))
),
CompoundTerm(Atom("-"), listOf(Integer(1))) to Negate(Integer(1)),
CompoundTerm(Atom("-"), listOf(Atom("-"))) to CompoundTerm(
Atom("-"),
listOf(Atom("-"))
),
CompoundTerm(Atom("-"), listOf(Integer(1), Integer(2))) to Subtract(
Integer(1), Integer(2)
),
CompoundTerm(Atom("-"), listOf(Atom("1"), Atom("2"))) to CompoundTerm(
Atom("-"),
listOf(Atom("1"), Atom("2"))
),
CompoundTerm(Atom("-"), listOf(Integer(1), Integer(2), Integer(3))) to CompoundTerm(
Atom("-"),
listOf(Integer(1), Integer(2), Integer(3))
)
)
)
}
@Test
fun `positive and add`() {
test(
mapOf(
Atom("+") to Atom("+"),
CompoundTerm(Atom("+"), emptyList()) to CompoundTerm(Atom("+"), emptyList()),
Atom("Positive") to Atom("Positive"),
CompoundTerm(Atom("Positive"), emptyList()) to CompoundTerm(
Atom("Positive"),
emptyList()
),
CompoundTerm(Atom("+"), listOf(Atom("a"))) to CompoundTerm(
Atom("+"),
listOf(Atom("a"))
),
CompoundTerm(Atom("+"), listOf(Integer(1))) to Positive(Integer(1)),
CompoundTerm(Atom("+"), listOf(Atom("+"))) to CompoundTerm(
Atom("+"),
listOf(Atom("+"))
),
CompoundTerm(Atom("+"), listOf(Integer(1), Integer(2))) to Add(
Integer(1), Integer(2)
),
CompoundTerm(Atom("+"), listOf(Atom("1"), Atom("2"))) to CompoundTerm(
Atom("+"),
listOf(Atom("1"), Atom("2"))
),
CompoundTerm(Atom("+"), listOf(Integer(1), Integer(2), Integer(3))) to CompoundTerm(
Atom("+"),
listOf(Integer(1), Integer(2), Integer(3))
)
)
)
}
@Test
fun multiply() {
test(
mapOf(
Atom("*") to Atom("*"),
CompoundTerm(Atom("*"), emptyList()) to CompoundTerm(Atom("*"), emptyList()),
Atom("Multiply") to Atom("Multiply"),
CompoundTerm(Atom("Multiply"), emptyList()) to CompoundTerm(
Atom("Multiply"),
emptyList()
),
CompoundTerm(Atom("*"), listOf(Atom("a"))) to CompoundTerm(
Atom("*"),
listOf(Atom("a"))
),
CompoundTerm(Atom("*"), listOf(Integer(1))) to CompoundTerm(Atom("*"), listOf(Integer(1))),
CompoundTerm(Atom("*"), listOf(Atom("*"))) to CompoundTerm(
Atom("*"),
listOf(Atom("*"))
),
CompoundTerm(Atom("*"), listOf(Integer(1), Integer(2))) to Multiply(
Integer(1), Integer(2)
),
CompoundTerm(Atom("*"), listOf(Atom("1"), Atom("2"))) to CompoundTerm(
Atom("*"),
listOf(Atom("1"), Atom("2"))
),
CompoundTerm(Atom("*"), listOf(Integer(1), Integer(2), Integer(3))) to CompoundTerm(
Atom("*"),
listOf(Integer(1), Integer(2), Integer(3))
)
)
)
}
@Test
fun divide() {
test(
mapOf(
Atom("/") to Atom("/"),
CompoundTerm(Atom("/"), emptyList()) to CompoundTerm(Atom("/"), emptyList()),
Atom("Divide") to Atom("Divide"),
CompoundTerm(Atom("Divide"), emptyList()) to CompoundTerm(
Atom("Divide"),
emptyList()
),
CompoundTerm(Atom("/"), listOf(Atom("a"))) to CompoundTerm(
Atom("/"),
listOf(Atom("a"))
),
CompoundTerm(Atom("/"), listOf(Integer(1))) to CompoundTerm(Atom("/"), listOf(Integer(1))),
CompoundTerm(Atom("/"), listOf(Atom("/"))) to CompoundTerm(
Atom("/"),
listOf(Atom("/"))
),
CompoundTerm(Atom("/"), listOf(Integer(1), Integer(2))) to Divide(
Integer(1), Integer(2)
),
CompoundTerm(Atom("/"), listOf(Atom("1"), Atom("2"))) to CompoundTerm(
Atom("/"),
listOf(Atom("1"), Atom("2"))
),
CompoundTerm(Atom("/"), listOf(Integer(1), Integer(2), Integer(3))) to CompoundTerm(
Atom("/"),
listOf(Integer(1), Integer(2), Integer(3))
)
)
)
}
@Test
fun between() {
test(
mapOf(
Atom("between") to Atom("between"),
CompoundTerm(Atom("between"), emptyList()) to CompoundTerm(
Atom("between"),
emptyList()
),
Atom("Between") to Atom("Between"),
CompoundTerm(Atom("Between"), emptyList()) to CompoundTerm(
Atom("Between"),
emptyList()
),
CompoundTerm(Atom("between"), listOf(Atom("a"))) to CompoundTerm(
Atom("between"),
listOf(Atom("a"))
),
CompoundTerm(Atom("between"), listOf(Integer(1))) to CompoundTerm(
Atom("between"),
listOf(Integer(1))
),
CompoundTerm(Atom("between"), listOf(Atom("between"))) to CompoundTerm(
Atom("between"),
listOf(Atom("between"))
),
CompoundTerm(Atom("between"), listOf(Integer(1), Integer(2))) to CompoundTerm(
Atom("between"),
listOf(Integer(1), Integer(2))
),
CompoundTerm(Atom("between"), listOf(Integer(1), Integer(2), Integer(3))) to Between(
Integer(1), Integer(2), Integer(3)
),
)
)
}
@Test
fun `fun combinations`() {
/*
* [X - 1] is [(1 + 2) * ((12 / 3) - 0)]
* should return
* X = 13
*/
val sum_ = CompoundTerm(Atom("+"), listOf(Integer(1), Integer(2)))
val sum = Add(Integer(1), Integer(2))
val div_ = CompoundTerm(Atom("/"), listOf(Integer(12), Integer(3)))
val div = Divide(Integer(12), Integer(3))
val sub_ = CompoundTerm(Atom("-"), listOf(div_, Integer(0)))
val sub = Subtract(div, Integer(0))
val right_ = CompoundTerm(Atom("*"), listOf(sum_, sub_))
val right = Multiply(sum, sub)
val left_ = CompoundTerm(Atom("-"), listOf(Variable("X"), Integer(1)))
val left = Subtract(Variable("X"), Integer(1))
val expr_ = CompoundTerm(Atom("is"), listOf(left_, right_))
val expr = Is(left, right)
val result = OpenPreprocessor().preprocess(expr_)
assertEquals(expr, result)
assertEquals(Is::class, result::class)
val `is` = result as Is
assertEquals(left, `is`.number)
assertEquals(Subtract::class, `is`.number::class)
assertEquals(right, `is`.expr)
assertEquals(Multiply::class, `is`.expr::class)
val multiply = `is`.expr as Multiply
assertEquals(sum, multiply.expr1)
assertEquals(Add::class, multiply.expr1::class)
}
}
@ -34,123 +338,99 @@ class PreprocessorTests {
@Test
fun fail() {
val tests = mapOf(
Atom("fail") to Fail,
CompoundTerm(Atom("fail"), emptyList()) to Fail,
Atom("Fail") to Atom("Fail"),
CompoundTerm(Atom("Fail"), emptyList()) to CompoundTerm(Atom("Fail"), emptyList()),
CompoundTerm(Atom("fail"), listOf(Atom("a"))) to CompoundTerm(Atom("fail"), listOf(Atom("a"))),
CompoundTerm(Atom("fail"), listOf(Atom("fail"))) to CompoundTerm(Atom("fail"), listOf(Fail))
test(
mapOf(
Atom("fail") to Fail,
CompoundTerm(Atom("fail"), emptyList()) to Fail,
Atom("Fail") to Atom("Fail"),
CompoundTerm(Atom("Fail"), emptyList()) to CompoundTerm(Atom("Fail"), emptyList()),
CompoundTerm(Atom("fail"), listOf(Atom("a"))) to CompoundTerm(Atom("fail"), listOf(Atom("a"))),
CompoundTerm(Atom("fail"), listOf(Atom("fail"))) to CompoundTerm(Atom("fail"), listOf(Fail))
)
)
for ((input, expected) in tests) {
val result = preprocessor.preprocess(input)
assertEquals(expected, result, "Expected preprocessed")
assertEquals(expected::class, result::class, "Expected same class")
}
}
@Test
fun `true`() {
val tests = mapOf(
Atom("true") to True,
CompoundTerm(Atom("true"), emptyList()) to True,
Atom("True") to Atom("True"),
CompoundTerm(Atom("True"), emptyList()) to CompoundTerm(Atom("True"), emptyList()),
CompoundTerm(Atom("true"), listOf(Atom("a"))) to CompoundTerm(Atom("true"), listOf(Atom("a"))),
CompoundTerm(Atom("true"), listOf(Atom("true"))) to CompoundTerm(Atom("true"), listOf(True))
test(
mapOf(
Atom("true") to True,
CompoundTerm(Atom("true"), emptyList()) to True,
Atom("True") to Atom("True"),
CompoundTerm(Atom("True"), emptyList()) to CompoundTerm(Atom("True"), emptyList()),
CompoundTerm(Atom("true"), listOf(Atom("a"))) to CompoundTerm(Atom("true"), listOf(Atom("a"))),
CompoundTerm(Atom("true"), listOf(Atom("true"))) to CompoundTerm(Atom("true"), listOf(True))
)
)
for ((input, expected) in tests) {
val result = preprocessor.preprocess(input)
assertEquals(expected, result, "Expected preprocessed")
assertEquals(expected::class, result::class, "Expected same class")
}
}
@Test
fun cut() {
val tests = mapOf(
Atom("!") to Cut(),
CompoundTerm(Atom("!"), emptyList()) to Cut(),
CompoundTerm(Atom("!"), listOf(Atom("a"))) to CompoundTerm(Atom("!"), listOf(Atom("a"))),
CompoundTerm(Atom("!"), listOf(Atom("!"))) to CompoundTerm(Atom("!"), listOf(Cut()))
test(
mapOf(
Atom("!") to Cut(),
CompoundTerm(Atom("!"), emptyList()) to Cut(),
CompoundTerm(Atom("!"), listOf(Atom("a"))) to CompoundTerm(Atom("!"), listOf(Atom("a"))),
CompoundTerm(Atom("!"), listOf(Atom("!"))) to CompoundTerm(Atom("!"), listOf(Cut()))
)
)
for ((input, expected) in tests) {
val result = preprocessor.preprocess(input)
assertEquals(expected, result, "Expected preprocessed")
assertEquals(expected::class, result::class, "Expected same class")
}
}
@Test
fun conjunction() {
val tests = mapOf(
CompoundTerm(Atom(","), listOf(Atom("a"), Atom("b"))) to Conjunction(Atom("a"), Atom("b")),
CompoundTerm(Atom(","), listOf(Atom("a"), Atom("b"), Atom("c"))) to CompoundTerm(
Atom(","),
listOf(Atom("a"), Atom("b"), Atom("c"))
),
// Nested conjunctions
CompoundTerm(
Atom(","),
listOf(Atom("a"), CompoundTerm(Atom(","), listOf(Atom("b"), Atom("c"))))
) to Conjunction(Atom("a"), Conjunction(Atom("b"), Atom("c"))),
test(
mapOf(
CompoundTerm(Atom(","), listOf(Atom("a"), Atom("b"))) to Conjunction(Atom("a"), Atom("b")),
CompoundTerm(Atom(","), listOf(Atom("a"), Atom("b"), Atom("c"))) to CompoundTerm(
Atom(","),
listOf(Atom("a"), Atom("b"), Atom("c"))
),
// Nested conjunctions
CompoundTerm(
Atom(","),
listOf(Atom("a"), CompoundTerm(Atom(","), listOf(Atom("b"), Atom("c"))))
) to Conjunction(Atom("a"), Conjunction(Atom("b"), Atom("c"))),
)
)
for ((input, expected) in tests) {
val result = preprocessor.preprocess(input)
assertEquals(expected, result, "Expected preprocessed")
assertEquals(expected::class, result::class, "Expected same class")
}
}
@Test
fun disjunction() {
val tests = mapOf(
CompoundTerm(Atom(";"), listOf(Atom("a"), Atom("b"))) to Disjunction(Atom("a"), Atom("b")),
CompoundTerm(Atom(";"), listOf(Atom("a"), Atom("b"), Atom("c"))) to CompoundTerm(
Atom(";"),
listOf(Atom("a"), Atom("b"), Atom("c"))
),
// Nested disjunctions
CompoundTerm(
Atom(";"),
listOf(Atom("a"), CompoundTerm(Atom(";"), listOf(Atom("b"), Atom("c"))))
) to Disjunction(Atom("a"), Disjunction(Atom("b"), Atom("c"))),
test(
mapOf(
CompoundTerm(Atom(";"), listOf(Atom("a"), Atom("b"))) to Disjunction(Atom("a"), Atom("b")),
CompoundTerm(Atom(";"), listOf(Atom("a"), Atom("b"), Atom("c"))) to CompoundTerm(
Atom(";"),
listOf(Atom("a"), Atom("b"), Atom("c"))
),
// Nested disjunctions
CompoundTerm(
Atom(";"),
listOf(Atom("a"), CompoundTerm(Atom(";"), listOf(Atom("b"), Atom("c"))))
) to Disjunction(Atom("a"), Disjunction(Atom("b"), Atom("c"))),
)
)
for ((input, expected) in tests) {
val result = preprocessor.preprocess(input)
assertEquals(expected, result, "Expected preprocessed")
assertEquals(expected::class, result::class, "Expected same class")
}
}
@Test
fun not() {
val tests = mapOf(
CompoundTerm(Atom("\\+"), listOf(Atom("a"))) to Not(Atom("a")),
CompoundTerm(Atom("\\+"), listOf(Atom("a"), Atom("b"))) to CompoundTerm(
Atom("\\+"),
listOf(Atom("a"), Atom("b"))
),
// Nested not
CompoundTerm(
Atom("foo"),
listOf(
Atom("bar"),
CompoundTerm(Atom("\\+"), listOf(CompoundTerm(Atom("\\+"), listOf(Atom("baz")))))
)
) to CompoundTerm(Atom("foo"), listOf(Atom("bar"), Not(Not(Atom("baz"))))),
test(
mapOf(
CompoundTerm(Atom("\\+"), listOf(Atom("a"))) to Not(Atom("a")),
CompoundTerm(Atom("\\+"), listOf(Atom("a"), Atom("b"))) to CompoundTerm(
Atom("\\+"),
listOf(Atom("a"), Atom("b"))
),
// Nested not
CompoundTerm(
Atom("foo"),
listOf(
Atom("bar"),
CompoundTerm(Atom("\\+"), listOf(CompoundTerm(Atom("\\+"), listOf(Atom("baz")))))
)
) to CompoundTerm(Atom("foo"), listOf(Atom("bar"), Not(Not(Atom("baz"))))),
)
)
for ((input, expected) in tests) {
val result = preprocessor.preprocess(input)
assertEquals(expected, result, "Expected preprocessed")
assertEquals(expected::class, result::class, "Expected same class")
}
}
}
}