From 4a6850527f4e53cd10604ef2e197c622c795283a Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Tue, 15 Apr 2025 18:23:04 +0200 Subject: [PATCH] feat: Floating point arithmetic --- src/prolog/ast/arithmetic/Float.kt | 34 +++ src/prolog/ast/arithmetic/Integer.kt | 40 +++ src/prolog/ast/arithmetic/Number.kt | 10 + src/prolog/ast/terms/Integer.kt | 30 --- src/prolog/ast/terms/Term.kt | 5 +- src/prolog/builtins/arithmeticOperators.kt | 12 +- src/prolog/logic/arithmetic.kt | 26 +- src/prolog/logic/unification.kt | 15 +- .../builtins/ArithmeticOperatorsTests.kt | 149 +++++++++- .../prolog/builtins/ControlOperatorsTests.kt | 2 +- tests/prolog/builtins/UnificationTest.kt | 2 +- tests/prolog/logic/ArithmeticTests.kt | 255 +++++++++++++++++- tests/prolog/logic/UnifyTest.kt | 2 +- 13 files changed, 527 insertions(+), 55 deletions(-) create mode 100644 src/prolog/ast/arithmetic/Float.kt create mode 100644 src/prolog/ast/arithmetic/Integer.kt create mode 100644 src/prolog/ast/arithmetic/Number.kt delete mode 100644 src/prolog/ast/terms/Integer.kt diff --git a/src/prolog/ast/arithmetic/Float.kt b/src/prolog/ast/arithmetic/Float.kt new file mode 100644 index 0000000..3bbf694 --- /dev/null +++ b/src/prolog/ast/arithmetic/Float.kt @@ -0,0 +1,34 @@ +package prolog.ast.arithmetic + +import prolog.Substitutions + +class Float(override val value: kotlin.Float): Number { + // Floats are already evaluated + override fun simplify(subs: Substitutions): Simplification = Simplification(this, this) + + override fun toString(): String = value.toString() + + override operator fun plus(other: Number): Number = when (other) { + is Float -> Float(value + other.value) + is Integer -> Float(value + other.value.toFloat()) + else -> throw IllegalArgumentException("Cannot add $this and $other") + } + + override operator fun minus(other: Number): Number = when (other) { + is Float -> Float(value - other.value) + is Integer -> Float(value - other.value.toFloat()) + else -> throw IllegalArgumentException("Cannot subtract $this and $other") + } + + override operator fun div(other: Number): Number = when (other) { + is Float -> Float(value / other.value) + is Integer -> Float(value / other.value.toFloat()) + else -> throw IllegalArgumentException("Cannot divide $this and $other") + } + + override operator fun times(other: Number): Number = when (other) { + is Float -> Float(value * other.value) + is Integer -> Float(value * other.value.toFloat()) + else -> throw IllegalArgumentException("Cannot multiply $this and $other") + } +} \ No newline at end of file diff --git a/src/prolog/ast/arithmetic/Integer.kt b/src/prolog/ast/arithmetic/Integer.kt new file mode 100644 index 0000000..50028a9 --- /dev/null +++ b/src/prolog/ast/arithmetic/Integer.kt @@ -0,0 +1,40 @@ +package prolog.ast.arithmetic + +import prolog.Substitutions + +data class Integer(override val value: Int) : Number { + // Integers are already evaluated + override fun simplify(subs: Substitutions): Simplification = Simplification(this, this) + + override fun toString(): String = value.toString() + + override operator fun plus(other: Number): Number = when (other) { + is Float -> Float(value + other.value) + is Integer -> Integer(value + other.value) + else -> throw IllegalArgumentException("Cannot add $this and $other") + } + + override operator fun minus(other: Number): Number = when (other) { + is Float -> Float(value - other.value) + is Integer -> Integer(value - other.value) + else -> throw IllegalArgumentException("Cannot subtract $this and $other") + } + + override operator fun div(other: Number): Number = when (other) { + is Float -> Float(value / other.value) + is Integer -> { + if (value / other.value * other.value == value) { + Integer(value / other.value) + } else { + Float(value / other.value.toFloat()) + } + } + else -> throw IllegalArgumentException("Cannot divide $this and $other") + } + + override operator fun times(other: Number): Number = when (other) { + is Float -> Float(value * other.value) + is Integer -> Integer(value * other.value) + else -> throw IllegalArgumentException("Cannot multiply $this and $other") + } +} diff --git a/src/prolog/ast/arithmetic/Number.kt b/src/prolog/ast/arithmetic/Number.kt new file mode 100644 index 0000000..959152d --- /dev/null +++ b/src/prolog/ast/arithmetic/Number.kt @@ -0,0 +1,10 @@ +package prolog.ast.arithmetic + +interface Number: Expression { + val value: kotlin.Number + + fun plus(other: Number): Number + fun minus(other: Number): Number + fun times(other: Number): Number + fun div(other: Number): Number +} \ No newline at end of file diff --git a/src/prolog/ast/terms/Integer.kt b/src/prolog/ast/terms/Integer.kt deleted file mode 100644 index d0c96c9..0000000 --- a/src/prolog/ast/terms/Integer.kt +++ /dev/null @@ -1,30 +0,0 @@ -package prolog.ast.terms - -import prolog.Substitutions -import prolog.ast.arithmetic.Expression -import prolog.ast.arithmetic.Simplification - -data class Integer(val value: Int): Expression { - // Integers are already evaluated - override fun simplify(subs: Substitutions): Simplification = Simplification(this, this) - - override fun toString(): String { - return value.toString() - } - - operator fun plus(other: Integer): Integer { - return Integer(value + other.value) - } - - operator fun minus(other: Integer): Integer { - return Integer(value - other.value) - } - - operator fun times(other: Integer): Integer { - return Integer(value * other.value) - } - - operator fun div(other: Integer): Integer { - return Integer(value / other.value) - } -} diff --git a/src/prolog/ast/terms/Term.kt b/src/prolog/ast/terms/Term.kt index ad5981f..0fdad49 100644 --- a/src/prolog/ast/terms/Term.kt +++ b/src/prolog/ast/terms/Term.kt @@ -5,8 +5,9 @@ import prolog.logic.compare /** * Value in Prolog. * - * A [Term] is either a [Variable], [Atom], [Integer], float or [CompoundTerm]. - * In addition, SWI-Prolog also defines the type string. + * A [Term] is either a [Variable], [Atom], [Integer][prolog.ast.arithmetic.Integer], + * [Float][prolog.ast.arithmetic.Float] or [CompoundTerm]. + * In addition, SWI-Prolog also defines the type TODO string. */ interface Term : Comparable { override fun compareTo(other: Term): Int = compare(this, other, emptyMap()) diff --git a/src/prolog/builtins/arithmeticOperators.kt b/src/prolog/builtins/arithmeticOperators.kt index 8e44df8..2463313 100644 --- a/src/prolog/builtins/arithmeticOperators.kt +++ b/src/prolog/builtins/arithmeticOperators.kt @@ -4,6 +4,7 @@ import prolog.Answers import prolog.Substitutions import prolog.ast.arithmetic.ArithmeticOperator import prolog.ast.arithmetic.Expression +import prolog.ast.arithmetic.Integer import prolog.ast.arithmetic.Simplification import prolog.ast.logic.Satisfiable import prolog.ast.terms.* @@ -115,7 +116,6 @@ open class Subtract(private val expr1: Expression, private val expr2: Expression } } -// TODO Expr * Expr /** * Result = Expr1 * Expr2 */ @@ -129,7 +129,15 @@ class Multiply(private val expr1: Expression, private val expr2: Expression) : } } -// TODO Expr / Expr +class Divide(private val expr1: Expression, private val expr2: Expression) : + ArithmeticOperator(Atom("/"), expr1, expr2) { + override fun simplify(subs: Substitutions): Simplification { + val result = Variable("Result") + val map = div(expr1, expr2, result, subs) + val simplification = result.simplify(map.first().getOrThrow()) + return Simplification(this, simplification.to) + } +} // TODO Expr mod Expr diff --git a/src/prolog/logic/arithmetic.kt b/src/prolog/logic/arithmetic.kt index 8a26a52..0d4d341 100644 --- a/src/prolog/logic/arithmetic.kt +++ b/src/prolog/logic/arithmetic.kt @@ -1,12 +1,12 @@ package prolog.logic import prolog.Answers -import prolog.Substitution import prolog.Substitutions import prolog.ast.arithmetic.Expression -import prolog.ast.terms.Integer +import prolog.ast.arithmetic.Integer import prolog.ast.terms.Term import prolog.ast.terms.Variable +import prolog.ast.arithmetic.Number /** * Low and High are integers, High ≥Low. @@ -78,18 +78,24 @@ fun succ(term1: Expression, term2: Expression, subs: Substitutions): Answers { * At least two of the three arguments must be instantiated to integers. */ fun plus(term1: Expression, term2: Expression, term3: Expression, subs: Substitutions): Answers = - operate(term1, term2, term3, subs, Integer::plus, Integer::minus) + operate(term1, term2, term3, subs, Number::plus, Number::minus) + +fun minus(term1: Expression, term2: Expression, term3: Expression, subs: Substitutions): Answers = + operate(term1, term2, term3, subs, Number::minus, Number::plus) fun mul(term1: Expression, term2: Expression, term3: Expression, subs: Substitutions): Answers = - operate(term1, term2, term3, subs, Integer::times, Integer::div) + operate(term1, term2, term3, subs, Number::times, Number::div) + +fun div(term1: Expression, term2: Expression, term3: Expression, subs: Substitutions): Answers = + operate(term1, term2, term3, subs, Number::div, Number::times) fun operate( term1: Expression, term2: Expression, term3: Expression, subs: Substitutions, - op: (Integer, Integer) -> Integer, - inverseOp: (Integer, Integer) -> Integer + op: (Number, Number) -> Number, + inverseOp: (Number, Number) -> Number ): Answers = sequence { val t1 = applySubstitution(term1, subs) val t2 = applySubstitution(term2, subs) @@ -101,8 +107,8 @@ fun operate( val e2 = t2.simplify(subs) val e3 = t3.simplify(subs) - val int3Value = op(e1.to as Integer, e2.to as Integer) - if (int3Value == e3.to as Integer) { + val int3Value = op(e1.to as Number, e2.to as Number) + if (equivalent(int3Value, e3.to, emptyMap())) { val opSubs: Substitutions = listOfNotNull(e1.mapped, e2.mapped, e3.mapped) .filter{ pair: Pair? -> pair != null && !subs.contains(pair.first) } .toMap() @@ -114,7 +120,7 @@ fun operate( val e1 = t1.simplify(subs) val e2 = t2.simplify(subs) - val int3Value = op(e1.to as Integer, e2.to as Integer) + val int3Value = op(e1.to as Number, e2.to as Number) val int3 = t3 as Variable yield(Result.success(mapOf(int3 to int3Value) + listOfNotNull(e1.mapped, e2.mapped))) } @@ -124,7 +130,7 @@ fun operate( val e = if (nonvariable(t1, subs)) t1.simplify(subs) else t2.simplify(subs) val e3 = t3.simplify(subs) - val value = inverseOp(e3.to as Integer, e.to as Integer) + val value = inverseOp(e3.to as Number, e.to as Number) val int = t as Variable yield(Result.success(mapOf(int to value) + listOfNotNull(e.mapped, e3.mapped))) } diff --git a/src/prolog/logic/unification.kt b/src/prolog/logic/unification.kt index a7b6274..ad24281 100644 --- a/src/prolog/logic/unification.kt +++ b/src/prolog/logic/unification.kt @@ -7,6 +7,9 @@ import prolog.ast.arithmetic.Expression import prolog.ast.logic.LogicOperator import prolog.ast.terms.* import kotlin.NoSuchElementException +import prolog.ast.arithmetic.Number +import prolog.ast.arithmetic.Integer +import prolog.ast.arithmetic.Float // Apply substitutions to a term fun applySubstitution(term: Term, subs: Substitutions): Term = when { @@ -101,6 +104,7 @@ fun equivalent(term1: Term, term2: Term, subs: Substitutions): Boolean { term1 is Atom && term2 is Atom -> compare(term1, term2, subs) == 0 term1 is Structure && term2 is Structure -> compare(term1, term2, subs) == 0 term1 is Integer && term2 is Integer -> compare(term1, term2, subs) == 0 + term1 is Number && term2 is Number -> compare(term1, term2, subs) == 0 term1 is Variable && term2 is Variable -> term1 == term2 term1 is Variable -> term1 in subs && equivalent(subs[term1]!!, term2, subs) term2 is Variable -> term2 in subs && equivalent(subs[term2]!!, term1, subs) @@ -119,16 +123,17 @@ fun compare(term1: Term, term2: Term, subs: Substitutions): Int { is Variable -> { when (t2) { is Variable -> t1.name.compareTo(t2.name) - is Integer -> -1 + is Number -> -1 is Atom -> -1 is Structure -> -1 else -> throw IllegalArgumentException("Cannot compare $t1 with $t2") } } - is Integer -> { + is Number -> { when (t2) { is Variable -> 1 - is Integer -> t1.value.compareTo(t2.value) + is Integer -> (t1.value as Int).compareTo(t2.value) + is Float -> (t1.value as kotlin.Float).compareTo(t2.value) is Atom -> -1 is Structure -> -1 else -> throw IllegalArgumentException("Cannot compare $t1 with $t2") @@ -137,7 +142,7 @@ fun compare(term1: Term, term2: Term, subs: Substitutions): Int { is Atom -> { when (t2) { is Variable -> 1 - is Integer -> 1 + is Number -> 1 is Atom -> t1.name.compareTo(t2.name) is Structure -> -1 else -> throw IllegalArgumentException("Cannot compare $t1 with $t2") @@ -146,7 +151,7 @@ fun compare(term1: Term, term2: Term, subs: Substitutions): Int { is Structure -> { when (t2) { is Variable -> 1 - is Integer -> 1 + is Number -> 1 is Atom -> 1 is Structure -> { val arityComparison = t1.arguments.size.compareTo(t2.arguments.size) diff --git a/tests/prolog/builtins/ArithmeticOperatorsTests.kt b/tests/prolog/builtins/ArithmeticOperatorsTests.kt index af5fbbb..6ce5980 100644 --- a/tests/prolog/builtins/ArithmeticOperatorsTests.kt +++ b/tests/prolog/builtins/ArithmeticOperatorsTests.kt @@ -4,8 +4,8 @@ import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import prolog.Substitutions -import prolog.ast.terms.Integer -import prolog.ast.terms.Term +import prolog.ast.arithmetic.Float +import prolog.ast.arithmetic.Integer import prolog.ast.terms.Variable import prolog.logic.equivalent @@ -285,6 +285,111 @@ class ArithmeticOperatorsTests { assertEquals(0, result.size, "X should not be equal to Y") } + @Test + fun `var is mul`() { + val op = Is( + Variable("X"), + Multiply(Integer(2), Integer(3)) + ) + + val result = op.satisfy(emptyMap()).toList() + + assertEquals(1, result.size, "X should be equal to 6") + assertTrue(result[0].isSuccess, "X should be equal to 6") + val subs = result[0].getOrNull()!! + assertEquals(1, subs.size, "X should be rebound") + assertTrue( + equivalent(Integer(6), subs[Variable("X")]!!, emptyMap()), + "X should be equal to 6" + ) + } + + @Test + fun `bound-var is mul`() { + val t1 = Variable("X") + + val op = Is( + t1, + Multiply(Integer(2), Integer(3)) + ) + val map: Substitutions = mapOf(t1 to Integer(6)) + + val result = op.satisfy(map).toList() + + assertEquals(1, result.size, "X should be equal to 6") + assertTrue(result[0].isSuccess, "X should be equal to 6") + val subs = result[0].getOrNull()!! + assertTrue(subs.isEmpty(), "X should not be rebound") + } + + @Test + fun `var is bound-to-mul`() { + val t1 = Variable("X") + val t2 = Variable("Y") + + val op = Is(t1, t2) + val map: Substitutions = mapOf(t1 to Integer(6), t2 to Multiply(Integer(2), Integer(3))) + + val result = op.satisfy(map).toList() + + assertEquals(1, result.size, "X should be equal to Y") + assertTrue(result[0].isSuccess, "X should be equal to Y") + val subs = result[0].getOrNull()!! + assertTrue(subs.isEmpty(), "X should not be rebound") + } + + @Test + fun `2 is 4 div 2`() { + val op = Is( + Integer(2), + Divide(Integer(4), Integer(2)) + ) + + val result = op.satisfy(emptyMap()).toList() + + assertEquals(1, result.size, "2 should be equal to 4 / 2") + assertTrue(result[0].isSuccess, "2 should be equal to 4 / 2") + val subs = result[0].getOrNull()!! + assertTrue(subs.isEmpty(), "2 should not be rebound") + } + + @Test + fun `4 div 2 is var`() { + val op = Is( + Variable("X"), + Divide(Integer(4), Integer(2)) + ) + + val result = op.satisfy(emptyMap()).toList() + + assertEquals(1, result.size, "X should be equal to 2") + assertTrue(result[0].isSuccess, "X should be equal to 2") + val subs = result[0].getOrNull()!! + assertEquals(1, subs.size, "X should be rebound") + assertTrue( + equivalent(Integer(2), subs[Variable("X")]!!, emptyMap()), + "X should be equal to 2" + ) + } + + @Test + fun `bound-var is 4 div 2`() { + val t1 = Variable("X") + + val op = Is( + t1, + Divide(Integer(4), Integer(2)) + ) + val map: Substitutions = mapOf(t1 to Integer(2)) + + val result = op.satisfy(map).toList() + + assertEquals(1, result.size, "X should be equal to 2") + assertTrue(result[0].isSuccess, "X should be equal to 2") + val subs = result[0].getOrNull()!! + assertTrue(subs.isEmpty(), "X should not be rebound") + } + /** * ?- between(1, 2, X), Y is 1 + X. * X = 1, Y = 2 ; @@ -441,4 +546,44 @@ class ArithmeticOperatorsTests { assertEquals(Integer(6), result.to, "2 * 3 should be equal to 6") } + + @Test + fun `Divide 1 and 1 to get 1`() { + val t1 = Integer(1) + val t2 = Integer(1) + + val result = Divide(t1, t2).simplify(emptyMap()) + + assertEquals(Integer(1), result.to, "1 / 1 should be equal to 1") + } + + @Test + fun `Divide 2 and 2 to get 1`() { + val t1 = Integer(2) + val t2 = Integer(2) + + val result = Divide(t1, t2).simplify(emptyMap()) + + assertEquals(Integer(1), result.to, "2 / 2 should be equal to 1") + } + + @Test + fun `Divide 12 and 3 to get 4`() { + val t1 = Integer(12) + val t2 = Integer(3) + + val result = Divide(t1, t2).simplify(emptyMap()) + + assertEquals(Integer(4), result.to, "12 / 3 should be equal to 4") + } + + @Test + fun `Divide 1 and 2 to get float`() { + val t1 = Integer(1) + val t2 = Integer(2) + + val result = Divide(t1, t2).simplify(emptyMap()) + + assertTrue(equivalent(result.to, Float(0.5f), emptyMap()), "1 / 2 should be equal to 0.5") + } } diff --git a/tests/prolog/builtins/ControlOperatorsTests.kt b/tests/prolog/builtins/ControlOperatorsTests.kt index 598d3fd..15ad926 100644 --- a/tests/prolog/builtins/ControlOperatorsTests.kt +++ b/tests/prolog/builtins/ControlOperatorsTests.kt @@ -8,7 +8,7 @@ import prolog.ast.logic.Fact import prolog.ast.logic.Rule import prolog.ast.terms.Atom import prolog.ast.terms.CompoundTerm -import prolog.ast.terms.Integer +import prolog.ast.arithmetic.Integer import prolog.ast.terms.Variable class ControlOperatorsTests { diff --git a/tests/prolog/builtins/UnificationTest.kt b/tests/prolog/builtins/UnificationTest.kt index 6c92206..6ca212a 100644 --- a/tests/prolog/builtins/UnificationTest.kt +++ b/tests/prolog/builtins/UnificationTest.kt @@ -3,7 +3,7 @@ package prolog.builtins import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.Test import prolog.ast.terms.Atom -import prolog.ast.terms.Integer +import prolog.ast.arithmetic.Integer import prolog.ast.terms.Variable class UnificationTest { diff --git a/tests/prolog/logic/ArithmeticTests.kt b/tests/prolog/logic/ArithmeticTests.kt index 81b5f67..83fa8c7 100644 --- a/tests/prolog/logic/ArithmeticTests.kt +++ b/tests/prolog/logic/ArithmeticTests.kt @@ -5,7 +5,8 @@ import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.RepeatedTest import org.junit.jupiter.api.Test import prolog.Substitutions -import prolog.ast.terms.Integer +import prolog.ast.arithmetic.Integer +import prolog.ast.arithmetic.Float import prolog.ast.terms.Term import prolog.ast.terms.Variable @@ -141,6 +142,19 @@ class ArithmeticTests { assertTrue(result[0].getOrNull()!!.isEmpty(), "1 + 2 should be equal to 3") } + @Test + fun `1,0 + 2,0 = 3,0`() { + val t1 = Float(1.0f) + val t2 = Float(2.0f) + val t3 = Float(3.0f) + + val result = plus(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "1.0 + 1.0 should already be equal to 2.0") + } + @Test fun `1_plus_2_is_not_4`() { val t1 = Integer(1) @@ -152,6 +166,17 @@ class ArithmeticTests { assertTrue(result.none(), "1 + 2 should not be equal to 4") } + @Test + fun `1,0 plus 2,0 is not 4,0`() { + val t1 = Float(1.0f) + val t2 = Float(2.0f) + val t3 = Float(4.0f) + + val result = plus(t1, t2, t3, emptyMap()) + + assertTrue(result.none(), "1.0 + 2.0 should not be equal to 4.0") + } + @Test fun `1_plus_2_is_variable`() { val t1 = Integer(1) @@ -165,6 +190,22 @@ class ArithmeticTests { assertEquals(Integer(3), result[0].getOrNull()!![t3], "X should be equal to 3") } + @Test + fun `1,0 plus 2,0 is variable`() { + val t1 = Float(1.0f) + val t2 = Float(2.0f) + val t3 = Variable("X") + + val result = plus(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "1.0 + 2.0 should be equal to X") + assertTrue(result[0].isSuccess, "Expected success") + + val subs = result[0].getOrNull()!! + + assertTrue(equivalent(Float(3.0f), subs[t3]!!, emptyMap()), "X should be equal to 3.0") + } + @Test fun `1_plus_variable_is_3`() { val t1 = Integer(1) @@ -178,6 +219,22 @@ class ArithmeticTests { assertEquals(Integer(2), result[0].getOrNull()!![t2], "X should be equal to 2") } + @Test + fun `1,0 plus variable is 3,0`() { + val t1 = Float(1.0f) + val t2 = Variable("X") + val t3 = Float(3.0f) + + val result = plus(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "1.0 + X should be equal to 3.0") + assertTrue(result[0].isSuccess, "Expected success") + + val subs = result[0].getOrNull()!! + + assertTrue(equivalent(Float(2.0f), subs[t2]!!, emptyMap()), "X should be equal to 2.0") + } + @Test fun variable_plus_2_is_3() { val t1 = Variable("X") @@ -191,6 +248,22 @@ class ArithmeticTests { assertEquals(Integer(1), result[0].getOrNull()!![t1], "X should be equal to 1") } + @Test + fun `variable plus 2,0 is 3,0`() { + val t1 = Variable("X") + val t2 = Float(2.0f) + val t3 = Float(3.0f) + + val result = plus(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "X + 2.0 should be equal to 3.0") + assertTrue(result[0].isSuccess, "Expected success") + + val subs = result[0].getOrNull()!! + + assertTrue(equivalent(Float(1.0f), subs[t1]!!, emptyMap()), "X should be equal to 1.0") + } + @Test fun `1 plus 2 is bound-to-3-var`() { val t1 = Integer(1) @@ -204,6 +277,19 @@ class ArithmeticTests { assertTrue(result[0].getOrNull()!!.isEmpty(), "t3 should not be rebound") } + @Test + fun `1,0 plus 2,0 is bound-to-3,0-var`() { + val t1 = Float(1.0f) + val t2 = Float(2.0f) + val t3 = Variable("X") + + val result = plus(t1, t2, t3, mapOf(Variable("X") to Float(3.0f))).toList() + + assertEquals(1, result.size, "1.0 + 2.0 should be equal to X") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "t3 should not be rebound") + } + @Test fun `1 plus 2 is bound-to-4-var`() { val t1 = Integer(1) @@ -215,6 +301,17 @@ class ArithmeticTests { assertTrue(result.none(), "1 + 2 should not be equal to X") } + @Test + fun `1,0 plus 2,0 is bound-to-4,0-var`() { + val t1 = Float(1.0f) + val t2 = Float(2.0f) + val t3 = Variable("X") + + val result = plus(t1, t2, t3, mapOf(t3 to Float(4.0f))) + + assertTrue(result.none(), "1.0 + 2.0 should not be equal to X") + } + @Test fun `1 plus bound-to-2-var is 3`() { val t1 = Integer(1) @@ -228,6 +325,19 @@ class ArithmeticTests { assertTrue(result[0].getOrNull()!!.isEmpty(), "t2 should not be rebound") } + @Test + fun `1,0 plus bound-to-2,0-var is 3,0`() { + val t1 = Float(1.0f) + val t2 = Variable("X") + val t3 = Float(3.0f) + + val result = plus(t1, t2, t3, mapOf(t2 to Float(2.0f))).toList() + + assertEquals(1, result.size, "1.0 + X should be equal to 3.0") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "t2 should not be rebound") + } + @Test fun `1 plus bound-to-2-var is not 4`() { val t1 = Integer(1) @@ -239,6 +349,17 @@ class ArithmeticTests { assertTrue(result.none(), "1 + X should not be equal to 4") } + @Test + fun `1,0 plus bound-to-2,0-var is not 4,0`() { + val t1 = Float(1.0f) + val t2 = Variable("X") + val t3 = Float(4.0f) + + val result = plus(t1, t2, t3, mapOf(t2 to Float(2.0f))) + + assertTrue(result.none(), "1.0 + X should not be equal to 4.0") + } + @Test fun `bound-to-1-var plus 2 is 3`() { val t1 = Variable("X") @@ -252,6 +373,19 @@ class ArithmeticTests { assertTrue(result[0].getOrNull()!!.none(), "t1 should not be rebound") } + @Test + fun `bound-to-1-var plus 2,0 is 3,0`() { + val t1 = Variable("X") + val t2 = Float(2.0f) + val t3 = Float(3.0f) + + val result = plus(t1, t2, t3, mapOf(t1 to Integer(1))).toList() + + assertEquals(1, result.size, "X + 2.0 should be equal to 3.0") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.none(), "t1 should not be rebound") + } + @Test fun `bound-to-1-var plus 2 is not 4`() { val t1 = Variable("X") @@ -263,6 +397,17 @@ class ArithmeticTests { assertTrue(result.none(), "X + 2 should not be equal to 4") } + @Test + fun `bound-to-1-var plus 2,0 is not 4,0`() { + val t1 = Variable("X") + val t2 = Float(2.0f) + val t3 = Float(4.0f) + + val result = plus(t1, t2, t3, mapOf(t1 to Integer(1))) + + assertTrue(result.none(), "X + 2.0 should not be equal to 4.0") + } + @Test fun `two unbound vars plus should throw`() { val t1 = Variable("X") @@ -294,6 +439,37 @@ class ArithmeticTests { assertTrue(equivalent(result[0].getOrThrow()[t3]!!, Integer(3), result[0].getOrNull()!!), "Z should be equal to 3") } + @Test + fun `bound-to-1,0-var plus bound-to-2,0-var is variable`() { + val t1 = Variable("X") + val t2 = Variable("Y") + val t3 = Variable("Z") + + val map: Substitutions = mapOf( + t1 to Float(1.0f), + t2 to Float(2.0f), + ) + + val result = plus(t1, t2, t3, map).toList() + + assertTrue(result.isNotEmpty(), "X + Y should be equal to Z") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(equivalent(result[0].getOrThrow()[t3]!!, Float(3.0f), result[0].getOrNull()!!), "Z should be equal to 3.0") + } + + @Test + fun `int + float is float`() { + val t1 = Integer(1) + val t2 = Float(2.0f) + val t3 = Variable("X") + + val result = plus(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(equivalent(result[0].getOrThrow()[t3]!!, Float(3.0f), result[0].getOrNull()!!), "X should be equal to 3.0") + } + @Test fun `1 times 1 is 1`() { val t1 = Integer(1) @@ -307,6 +483,19 @@ class ArithmeticTests { assertTrue(result[0].getOrNull()!!.isEmpty(), "1 * 1 should already be equal to 1") } + @Test + fun `1,0 times 1,0 is 1,0`() { + val t1 = Float(1.0f) + val t2 = Float(1.0f) + val t3 = Float(1.0f) + + val result = mul(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "1.0 * 1.0 should already be equal to 1.0") + } + @Test fun `1 times 2 is 2`() { val t1 = Integer(1) @@ -320,6 +509,19 @@ class ArithmeticTests { assertTrue(result[0].getOrNull()!!.isEmpty(), "1 * 2 should already be equal to 2") } + @Test + fun `1,0 times 2,0 is 2,0`() { + val t1 = Float(1.0f) + val t2 = Float(2.0f) + val t3 = Float(2.0f) + + val result = mul(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "1.0 * 2.0 should already be equal to 2.0") + } + @Test fun `2 times 3 is 6`() { val t1 = Integer(2) @@ -333,6 +535,19 @@ class ArithmeticTests { assertTrue(result[0].getOrNull()!!.isEmpty(), "2 * 3 should already be equal to 6") } + @Test + fun `2,0 times 3,0 is 6,0`() { + val t1 = Float(2.0f) + val t2 = Float(3.0f) + val t3 = Float(6.0f) + + val result = mul(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "2.0 * 3.0 should already be equal to 6.0") + } + @Test fun `2 times 3 is not 4`() { val t1 = Integer(2) @@ -344,6 +559,30 @@ class ArithmeticTests { assertTrue(result.none(), "2 * 3 should not be equal to 4") } + @Test + fun `2,0 times 3,0 is not 4,0`() { + val t1 = Float(2.0f) + val t2 = Float(3.0f) + val t3 = Float(4.0f) + + val result = mul(t1, t2, t3, emptyMap()) + + assertTrue(result.none(), "2.0 * 3.0 should not be equal to 4.0") + } + + @Test + fun `int times float is float`() { + val t1 = Integer(2) + val t2 = Float(3.0f) + val t3 = Variable("X") + + val result = mul(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(equivalent(result[0].getOrThrow()[t3]!!, Float(6.0f), result[0].getOrNull()!!), "X should be equal to 6.0") + } + @RepeatedTest(100) fun `random test for mul`() { val t1 = Integer((0..1000).random()) @@ -356,4 +595,18 @@ class ArithmeticTests { assertTrue(result[0].isSuccess, "Expected success") assertEquals(Integer(t1.value * t2.value), result[0].getOrNull()!![t3], "X should be equal to ${t1.value * t2.value}") } + + @RepeatedTest(100) + fun `random test for mul with floats`() { + val t1 = Float((0..1000).random().toFloat()) + val t2 = Float((0..1000).random().toFloat()) + val t3 = Variable("X") + + val result = mul(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + val subs = result[0].getOrNull()!! + assertTrue(equivalent(subs[t3]!!, Float(t1.value * t2.value), subs), "X should be equal to ${t1.value * t2.value}") + } } diff --git a/tests/prolog/logic/UnifyTest.kt b/tests/prolog/logic/UnifyTest.kt index 5e40c5e..c306081 100644 --- a/tests/prolog/logic/UnifyTest.kt +++ b/tests/prolog/logic/UnifyTest.kt @@ -4,7 +4,7 @@ import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import prolog.Substitutions -import prolog.ast.terms.Integer +import prolog.ast.arithmetic.Integer import prolog.ast.terms.Atom import prolog.ast.terms.Structure import prolog.ast.terms.Variable