diff --git a/src/prolog/ast/logic/Provable.kt b/src/prolog/ast/logic/Provable.kt index 7bc4685..c587f39 100644 --- a/src/prolog/ast/logic/Provable.kt +++ b/src/prolog/ast/logic/Provable.kt @@ -2,6 +2,7 @@ package prolog.ast.logic import prolog.logic.Substituted + interface Provable { /** * Proves the current [Provable] instance. diff --git a/src/prolog/ast/terms/Integer.kt b/src/prolog/ast/terms/Integer.kt index 210566f..f3f9fcf 100644 --- a/src/prolog/ast/terms/Integer.kt +++ b/src/prolog/ast/terms/Integer.kt @@ -29,4 +29,12 @@ data class Integer(val value: Int): Term, Expression { operator fun minus(other: Integer): Integer { return Integer(value - other.value) } + + operator fun times(other: Integer): Integer { + return Integer(value * other.value) + } + + operator fun div(other: Integer): Integer { + return Integer(value / other.value) + } } diff --git a/src/prolog/builtins/arithmeticOperators.kt b/src/prolog/builtins/arithmeticOperators.kt index b78020c..4b2ac2f 100644 --- a/src/prolog/builtins/arithmeticOperators.kt +++ b/src/prolog/builtins/arithmeticOperators.kt @@ -96,7 +96,7 @@ open class Add(private val expr1: Expression, private val expr2: Expression) : override fun evaluate(subs: Substituted): Pair { val result = Variable("Result") val map = plus(expr1, expr2, result, subs) - return result.evaluate(map.first()) + return result.evaluate(map.first().getOrThrow()) } } @@ -108,7 +108,7 @@ open class Subtract(private val expr1: Expression, private val expr2: Expression override fun evaluate(subs: Substituted): Pair { val result = Variable("Result") val map = plus(expr2, result, expr1, subs) - return result.evaluate(map.first()) + return result.evaluate(map.first().getOrThrow()) } } @@ -121,7 +121,7 @@ class Multiply(private val expr1: Expression, private val expr2: Expression) : override fun evaluate(subs: Substituted): Pair { val result = Variable("Result") val map = mul(expr1, expr2, result, subs) - return result.evaluate(map.first()) + return result.evaluate(map.first().getOrThrow()) } } diff --git a/src/prolog/logic/arithmetic.kt b/src/prolog/logic/arithmetic.kt index dd00ba9..979d8f3 100644 --- a/src/prolog/logic/arithmetic.kt +++ b/src/prolog/logic/arithmetic.kt @@ -46,22 +46,26 @@ fun between( * @throws IllegalArgumentException the domain error not_less_than_zero if called with a negative integer. * E.g. succ(X, 0) fails silently and succ(X, -1) raises a domain error.125 */ -fun succ(term1: Expression, term2: Expression, subs: Substituted): Sequence { +fun succ(term1: Expression, term2: Expression, subs: Substituted): Sequence> { if (term2 is Integer) { require(term2.value >= 0) { "Domain error: not_less_than_zero" } } val result = plus(term1, Integer(1), term2, subs) // If term1 is a variable, we need to check if it is bound to a negative integer - return sequence { result.forEach { newSubs -> - val t1 = applySubstitution(term1, newSubs) - if (t1 is Variable && t1.alias().isPresent) { - val e1 = t1.evaluate(subs) - if(e1.first is Integer && (e1.first as Integer).value < 0) { - return@sequence + return sequence { + result.forEach { newSubs -> + if (newSubs.isSuccess) { + val t1 = applySubstitution(term1, newSubs.getOrNull()!!) + if (t1 is Variable && t1.alias().isPresent) { + val e1 = t1.evaluate(subs) + if (e1.first is Integer && (e1.first as Integer).value < 0) { + return@sequence + } + } } + yield(newSubs) } - yield(newSubs) - }} + } } /** @@ -69,82 +73,63 @@ fun succ(term1: Expression, term2: Expression, subs: Substituted): Sequence = sequence { +fun plus(term1: Expression, term2: Expression, term3: Expression, subs: Substituted): Sequence> = + operate(term1, term2, term3, subs, Integer::plus, Integer::minus) + +fun mul(term1: Expression, term2: Expression, term3: Expression, subs: Substituted): Sequence> = + operate(term1, term2, term3, subs, Integer::times, Integer::div) + +fun operate( + term1: Expression, + term2: Expression, + term3: Expression, + subs: Substituted, + op: (Integer, Integer) -> Integer, + inverseOp: (Integer, Integer) -> Integer +): Sequence> = sequence { val t1 = applySubstitution(term1, subs) val t2 = applySubstitution(term2, subs) val t3 = applySubstitution(term3, subs) - // At least two arguments must be Integers when { nonvariable(t1) && nonvariable(t2) && nonvariable(t3) -> { val e1 = t1.evaluate(subs) val e2 = t2.evaluate(subs) val e3 = t3.evaluate(subs) - val int3Value = e1.first as Integer + e2.first as Integer + val int3Value = op(e1.first as Integer, e2.first as Integer) if (int3Value == e3.first as Integer) { - yield(e1.second + e2.second + e3.second) + yield(Result.success(e1.second + e2.second + e3.second)) } } + nonvariable(t1) && nonvariable(t2) && variable(t3) -> { val e1 = t1.evaluate(subs) val e2 = t2.evaluate(subs) - val int3Value = e1.first as Integer + e2.first as Integer + val int3Value = op(e1.first as Integer, e2.first as Integer) val int3 = t3 as Variable int3.bind(int3Value) - yield(mapOf(int3 to int3Value) + e1.second + e2.second) + yield(Result.success(mapOf(int3 to int3Value) + e1.second + e2.second)) } - nonvariable(t1) && variable(t2) && nonvariable(t3) -> { - val e1 = t1.evaluate(subs) + + ((nonvariable(t1) && variable(t2)) || (variable(t1) && nonvariable(t2))) && nonvariable(t3) -> { + val t = if (nonvariable(t1)) t2 else t1 + val e = if (nonvariable(t1)) t1.evaluate(subs) else t2.evaluate(subs) val e3 = t3.evaluate(subs) - val int2Value = e3.first as Integer - e1.first as Integer - val int2 = t2 as Variable - int2.bind(int2Value) - yield(mapOf(int2 to int2Value) + e1.second + e3.second) + val value = inverseOp(e3.first as Integer, e.first as Integer) + val int = t as Variable + int.bind(value) + yield(Result.success(mapOf(int to value) + e.second + e3.second)) } - variable(t1) && nonvariable(t2) && nonvariable(t3) -> { - val e2 = t2.evaluate(subs) - val e3 = t3.evaluate(subs) - val int1Value = e3.first as Integer - e2.first as Integer - val int1 = t1 as Variable - int1.bind(int1Value) - yield(mapOf(int1 to int1Value) + e2.second + e3.second) - } else -> { - throw IllegalArgumentException("At least two arguments must be instantiated to integers") + yield(Result.failure(IllegalArgumentException("At least two arguments must be instantiated to integers"))) } } } -/** - * Recursive implementation of the multiply operator, logical programming-wise. - */ -fun mul(term1: Expression, term2: Expression, term3: Expression, subs: Substituted): Sequence = sequence { - val t1 = applySubstitution(term1, subs) - val t2 = applySubstitution(term2, subs) - val t3 = applySubstitution(term3, subs) - - // Base case - if (equivalent(t2, Integer(0))) { - yieldAll(Is(t3, Integer(0)).prove(subs)) - } - - // Recursive case - try { - val decremented = Variable("Decremented") - succ(decremented, t2, subs).forEach { decrementMap -> - val multiplied = Variable("Multiplied") - mul(t1, decremented, multiplied, subs + decrementMap).forEach { multipliedMap -> - yieldAll(plus(t1, multiplied, t3, subs + decrementMap + multipliedMap)) - } - } - } catch(_: Exception) { - } -} - // TODO divmod // TODO nth_integer_root_and_remainder \ No newline at end of file diff --git a/tests/prolog/builtins/ArithmeticOperatorsTests.kt b/tests/prolog/builtins/ArithmeticOperatorsTests.kt index 77c6a10..b27b222 100644 --- a/tests/prolog/builtins/ArithmeticOperatorsTests.kt +++ b/tests/prolog/builtins/ArithmeticOperatorsTests.kt @@ -418,9 +418,8 @@ class ArithmeticOperatorsTests { val t1 = Integer(2) val t2 = Integer(3) - val result = Multiply(t1, t2).evaluate(emptyMap()).toList() + val result = Multiply(t1, t2).evaluate(emptyMap()) - assertEquals(1, result.size, "There should only be one solution") - assertEquals(Integer(6), result, "2 * 3 should be equal to 6") + assertEquals(Integer(6), result.first, "2 * 3 should be equal to 6") } } diff --git a/tests/prolog/logic/ArithmeticTests.kt b/tests/prolog/logic/ArithmeticTests.kt index fdf4c68..d79e3b7 100644 --- a/tests/prolog/logic/ArithmeticTests.kt +++ b/tests/prolog/logic/ArithmeticTests.kt @@ -2,6 +2,7 @@ package prolog.logic import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.RepeatedTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import prolog.ast.terms.Integer @@ -53,10 +54,11 @@ class ArithmeticTests { val expected = Integer(1) - val result = succ(t1, expected, emptyMap()) + val result = succ(t1, expected, emptyMap()).toList() - assertTrue(result.any(), "Expected 0 + 1 to be equal to 1") - assertTrue(result.first().isEmpty(), "Expected no substitutions") + assertEquals(1, result.size, "Expected 0 + 1 to be equal to 1") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "Expected no substitutions") } @Test @@ -65,10 +67,11 @@ class ArithmeticTests { val expected = Integer(2) - val result = succ(t1, expected, emptyMap()) + val result = succ(t1, expected, emptyMap()).toList() - assertTrue(result.any(), "Expected 1 + 1 to be equal to 2") - assertTrue(result.first().isEmpty(), "Expected no substitutions") + assertEquals(1, result.size, "Expected 1 + 1 to be equal to 2") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "Expected no substitutions") } @Test @@ -80,8 +83,9 @@ class ArithmeticTests { val result = succ(t1, t2, emptyMap()).toList() - assertTrue(result.isNotEmpty(), "Expected 1 + 1 to be equal to X") - assertTrue(result[0][t2] == expected, "Expected X to be equal to 2") + assertEquals(1, result.size, "Expected 1 + 1 to be equal to X") + assertTrue(result[0].isSuccess, "Expected success") + assertEquals(expected, result[0].getOrNull()!![t2], "Expected X to be equal to 2") } @Test @@ -91,10 +95,11 @@ class ArithmeticTests { t1.bind(Integer(1)) - val result = succ(t1, t2, emptyMap()) + val result = succ(t1, t2, emptyMap()).toList() - assertTrue(result.any(), "Expected X + 1 to be equal to 2") - assertTrue(result.first().isEmpty(), "Expected no substitutions") + assertEquals(1, result.size, "Expected X + 1 to be equal to 2") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "Expected no substitutions") } @Test @@ -102,10 +107,11 @@ class ArithmeticTests { val t1 = Variable("X") val t2 = Integer(2) - val result = succ(t1, t2, mapOf(t1 to Integer(1))) + val result = succ(t1, t2, mapOf(t1 to Integer(1))).toList() - assertTrue(result.any(), "Expected X + 1 to be equal to 2") - assertTrue(result.first().isEmpty(), "Expected no substitutions") + assertEquals(1, result.size, "Expected X + 1 to be equal to 2") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "Expected no substitutions") } @Test @@ -117,8 +123,9 @@ class ArithmeticTests { val result = succ(t1, t2, emptyMap()).toList() - assertTrue(result.isNotEmpty(), "Expected X + 1 to be equal to Y") - assertTrue(result[0][t2] == Integer(2), "Expected Y to be equal to 2") + assertEquals(1, result.size, "Expected X + 1 to be equal to Y") + assertTrue(result[0].isSuccess, "Expected success") + assertEquals(Integer(2), result[0].getOrNull()!![t2], "Expected Y to be equal to 2") } @Test @@ -128,8 +135,9 @@ class ArithmeticTests { val result = succ(t1, t2, mapOf(t1 to Integer(1))).toList() - assertTrue(result.isNotEmpty(), "Expected X + 1 to be equal to Y") - assertTrue(result[0][t2] == Integer(2), "Expected Y to be equal to 2") + assertEquals(1, result.size, "Expected X + 1 to be equal to Y") + assertTrue(result[0].isSuccess, "Expected success") + assertEquals(Integer(2), result[0].getOrNull()!![t2], "Expected Y to be equal to 2") } @Test @@ -148,10 +156,11 @@ class ArithmeticTests { val t2 = Integer(2) val t3 = Integer(3) - val result = plus(t1, t2, t3, emptyMap()) + val result = plus(t1, t2, t3, emptyMap()).toList() - assertTrue(result.any(), "1 + 2 should be equal to 3") - assertTrue(result.first().isEmpty(), "1 + 2 should be equal to 3") + assertEquals(1, result.size, "1 + 2 should be equal to 3") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "1 + 2 should be equal to 3") } @Test @@ -173,8 +182,9 @@ class ArithmeticTests { val result = plus(t1, t2, t3, emptyMap()).toList() - assertTrue(result.isNotEmpty(), "1 + 2 should be equal to X") - assertTrue(equivalent(result[0][t3]!!, Integer(3)), "X should be equal to 3") + assertEquals(1, result.size, "1 + 2 should be equal to X") + assertTrue(result[0].isSuccess, "Expected success") + assertEquals(Integer(3), result[0].getOrNull()!![t3], "X should be equal to 3") } @Test @@ -185,8 +195,9 @@ class ArithmeticTests { val result = plus(t1, t2, t3, emptyMap()).toList() - assertTrue(result.isNotEmpty(), "1 + X should be equal to 3") - assertTrue(equivalent(result[0][t2]!!, Integer(2)), "X should be equal to 2") + assertEquals(1, result.size, "1 + X should be equal to 3") + assertTrue(result[0].isSuccess, "Expected success") + assertEquals(Integer(2), result[0].getOrNull()!![t2], "X should be equal to 2") } @Test @@ -197,8 +208,9 @@ class ArithmeticTests { val result = plus(t1, t2, t3, emptyMap()).toList() - assertTrue(result.isNotEmpty(), "X + 2 should be equal to 3") - assertTrue(equivalent(result[0][t1]!!, Integer(1)), "X should be equal to 1") + assertEquals(1, result.size, "X + 2 should be equal to 3") + assertTrue(result[0].isSuccess, "Expected success") + assertEquals(Integer(1), result[0].getOrNull()!![t1], "X should be equal to 1") } @Test @@ -209,10 +221,11 @@ class ArithmeticTests { t3.bind(Integer(3)) - val result = plus(t1, t2, t3, emptyMap()) + val result = plus(t1, t2, t3, emptyMap()).toList() - assertTrue(result.any(), "1 + 2 should be equal to X") - assertTrue(result.first().isEmpty(), "t3 should not be rebound") + assertEquals(1, result.size, "1 + 2 should be equal to X") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "t3 should not be rebound") } @Test @@ -221,10 +234,11 @@ class ArithmeticTests { val t2 = Integer(2) val t3 = Variable("X") - val result = plus(t1, t2, t3, mapOf(Variable("X") to Integer(3))) + val result = plus(t1, t2, t3, mapOf(Variable("X") to Integer(3))).toList() - assertTrue(result.any(), "1 + 2 should be equal to X") - assertTrue(result.first().isEmpty(), "t3 should not be rebound") + assertEquals(1, result.size, "1 + 2 should be equal to X") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "t3 should not be rebound") } @Test @@ -248,10 +262,11 @@ class ArithmeticTests { t2.bind(Integer(2)) - val result = plus(t1, t2, t3, emptyMap()) + val result = plus(t1, t2, t3, emptyMap()).toList() - assertTrue(result.any(), "1 + X should be equal to 3") - assertTrue(result.first().isEmpty(), "t2 should not be rebound") + assertEquals(1, result.size, "1 + X should be equal to 3") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "t2 should not be rebound") } @Test @@ -275,10 +290,11 @@ class ArithmeticTests { t1.bind(Integer(1)) - val result = plus(t1, t2, t3, emptyMap()) + val result = plus(t1, t2, t3, emptyMap()).toList() - assertTrue(result.any(), "X + 2 should be equal to 3") - assertTrue(result.first().none(), "t1 should not be rebound") + assertEquals(1, result.size, "X + 2 should be equal to 3") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.none(), "t1 should not be rebound") } @Test @@ -300,9 +316,11 @@ class ArithmeticTests { val t2 = Variable("Y") val t3 = Integer(3) - assertThrows { - plus(t1, t2, t3, emptyMap()) - } + val result = plus(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isFailure, "Expected failure") + assertTrue(result[0].exceptionOrNull() is IllegalArgumentException, "Expected IllegalArgumentException") } @Test @@ -317,7 +335,8 @@ class ArithmeticTests { val result = plus(t1, t2, t3, emptyMap()).toList() assertTrue(result.isNotEmpty(), "X + Y should be equal to Z") - assertTrue(equivalent(result[0][t3]!!, Integer(3)), "Z should be equal to 3") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(equivalent(result[0].getOrThrow()[t3]!!, Integer(3)), "Z should be equal to 3") } @Test @@ -329,7 +348,21 @@ class ArithmeticTests { val result = mul(t1, t2, t3, emptyMap()).toList() assertEquals(1, result.size, "There should be one solution") - assertTrue(result[0].isEmpty(), "1 * 1 should already be equal to 1") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "1 * 1 should already be equal to 1") + } + + @Test + fun `1 times 2 is 2`() { + val t1 = Integer(1) + val t2 = Integer(2) + val t3 = Integer(2) + + val result = mul(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "1 * 2 should already be equal to 2") } @Test @@ -341,7 +374,8 @@ class ArithmeticTests { val result = mul(t1, t2, t3, emptyMap()).toList() assertEquals(1, result.size, "There should be one solution") - assertTrue(result[0].isEmpty(), "2 * 3 should already be equal to 6") + assertTrue(result[0].isSuccess, "Expected success") + assertTrue(result[0].getOrNull()!!.isEmpty(), "2 * 3 should already be equal to 6") } @Test @@ -354,4 +388,17 @@ class ArithmeticTests { assertTrue(result.none(), "2 * 3 should not be equal to 4") } + + @RepeatedTest(100) + fun `random test for mul`() { + val t1 = Integer((0..1000).random()) + val t2 = Integer((0..1000).random()) + val t3 = Variable("X") + + val result = mul(t1, t2, t3, emptyMap()).toList() + + assertEquals(1, result.size, "There should be one solution") + assertTrue(result[0].isSuccess, "Expected success") + assertEquals(Integer(t1.value * t2.value), result[0].getOrNull()!![t3], "X should be equal to ${t1.value * t2.value}") + } }