feat: Arithmetic

This commit is contained in:
Tibo De Peuter 2025-04-11 21:11:59 +02:00
parent e73e5cbfc8
commit ac55ed4c64
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
6 changed files with 147 additions and 107 deletions

View file

@ -2,6 +2,7 @@ package prolog.ast.logic
import prolog.logic.Substituted
interface Provable {
/**
* Proves the current [Provable] instance.

View file

@ -29,4 +29,12 @@ data class Integer(val value: Int): Term, Expression {
operator fun minus(other: Integer): Integer {
return Integer(value - other.value)
}
operator fun times(other: Integer): Integer {
return Integer(value * other.value)
}
operator fun div(other: Integer): Integer {
return Integer(value / other.value)
}
}

View file

@ -96,7 +96,7 @@ open class Add(private val expr1: Expression, private val expr2: Expression) :
override fun evaluate(subs: Substituted): Pair<Term, Substituted> {
val result = Variable("Result")
val map = plus(expr1, expr2, result, subs)
return result.evaluate(map.first())
return result.evaluate(map.first().getOrThrow())
}
}
@ -108,7 +108,7 @@ open class Subtract(private val expr1: Expression, private val expr2: Expression
override fun evaluate(subs: Substituted): Pair<Term, Substituted> {
val result = Variable("Result")
val map = plus(expr2, result, expr1, subs)
return result.evaluate(map.first())
return result.evaluate(map.first().getOrThrow())
}
}
@ -121,7 +121,7 @@ class Multiply(private val expr1: Expression, private val expr2: Expression) :
override fun evaluate(subs: Substituted): Pair<Term, Substituted> {
val result = Variable("Result")
val map = mul(expr1, expr2, result, subs)
return result.evaluate(map.first())
return result.evaluate(map.first().getOrThrow())
}
}

View file

@ -46,22 +46,26 @@ fun between(
* @throws IllegalArgumentException the domain error not_less_than_zero if called with a negative integer.
* E.g. succ(X, 0) fails silently and succ(X, -1) raises a domain error.125
*/
fun succ(term1: Expression, term2: Expression, subs: Substituted): Sequence<Substituted> {
fun succ(term1: Expression, term2: Expression, subs: Substituted): Sequence<Result<Substituted>> {
if (term2 is Integer) {
require(term2.value >= 0) { "Domain error: not_less_than_zero" }
}
val result = plus(term1, Integer(1), term2, subs)
// If term1 is a variable, we need to check if it is bound to a negative integer
return sequence { result.forEach { newSubs ->
val t1 = applySubstitution(term1, newSubs)
if (t1 is Variable && t1.alias().isPresent) {
val e1 = t1.evaluate(subs)
if(e1.first is Integer && (e1.first as Integer).value < 0) {
return@sequence
return sequence {
result.forEach { newSubs ->
if (newSubs.isSuccess) {
val t1 = applySubstitution(term1, newSubs.getOrNull()!!)
if (t1 is Variable && t1.alias().isPresent) {
val e1 = t1.evaluate(subs)
if (e1.first is Integer && (e1.first as Integer).value < 0) {
return@sequence
}
}
}
yield(newSubs)
}
yield(newSubs)
}}
}
}
/**
@ -69,82 +73,63 @@ fun succ(term1: Expression, term2: Expression, subs: Substituted): Sequence<Subs
*
* At least two of the three arguments must be instantiated to integers.
*/
fun plus(term1: Expression, term2: Expression, term3: Expression, subs: Substituted): Sequence<Substituted> = sequence {
fun plus(term1: Expression, term2: Expression, term3: Expression, subs: Substituted): Sequence<Result<Substituted>> =
operate(term1, term2, term3, subs, Integer::plus, Integer::minus)
fun mul(term1: Expression, term2: Expression, term3: Expression, subs: Substituted): Sequence<Result<Substituted>> =
operate(term1, term2, term3, subs, Integer::times, Integer::div)
fun operate(
term1: Expression,
term2: Expression,
term3: Expression,
subs: Substituted,
op: (Integer, Integer) -> Integer,
inverseOp: (Integer, Integer) -> Integer
): Sequence<Result<Substituted>> = sequence {
val t1 = applySubstitution(term1, subs)
val t2 = applySubstitution(term2, subs)
val t3 = applySubstitution(term3, subs)
// At least two arguments must be Integers
when {
nonvariable(t1) && nonvariable(t2) && nonvariable(t3) -> {
val e1 = t1.evaluate(subs)
val e2 = t2.evaluate(subs)
val e3 = t3.evaluate(subs)
val int3Value = e1.first as Integer + e2.first as Integer
val int3Value = op(e1.first as Integer, e2.first as Integer)
if (int3Value == e3.first as Integer) {
yield(e1.second + e2.second + e3.second)
yield(Result.success(e1.second + e2.second + e3.second))
}
}
nonvariable(t1) && nonvariable(t2) && variable(t3) -> {
val e1 = t1.evaluate(subs)
val e2 = t2.evaluate(subs)
val int3Value = e1.first as Integer + e2.first as Integer
val int3Value = op(e1.first as Integer, e2.first as Integer)
val int3 = t3 as Variable
int3.bind(int3Value)
yield(mapOf(int3 to int3Value) + e1.second + e2.second)
yield(Result.success(mapOf(int3 to int3Value) + e1.second + e2.second))
}
nonvariable(t1) && variable(t2) && nonvariable(t3) -> {
val e1 = t1.evaluate(subs)
((nonvariable(t1) && variable(t2)) || (variable(t1) && nonvariable(t2))) && nonvariable(t3) -> {
val t = if (nonvariable(t1)) t2 else t1
val e = if (nonvariable(t1)) t1.evaluate(subs) else t2.evaluate(subs)
val e3 = t3.evaluate(subs)
val int2Value = e3.first as Integer - e1.first as Integer
val int2 = t2 as Variable
int2.bind(int2Value)
yield(mapOf(int2 to int2Value) + e1.second + e3.second)
val value = inverseOp(e3.first as Integer, e.first as Integer)
val int = t as Variable
int.bind(value)
yield(Result.success(mapOf(int to value) + e.second + e3.second))
}
variable(t1) && nonvariable(t2) && nonvariable(t3) -> {
val e2 = t2.evaluate(subs)
val e3 = t3.evaluate(subs)
val int1Value = e3.first as Integer - e2.first as Integer
val int1 = t1 as Variable
int1.bind(int1Value)
yield(mapOf(int1 to int1Value) + e2.second + e3.second)
}
else -> {
throw IllegalArgumentException("At least two arguments must be instantiated to integers")
yield(Result.failure(IllegalArgumentException("At least two arguments must be instantiated to integers")))
}
}
}
/**
* Recursive implementation of the multiply operator, logical programming-wise.
*/
fun mul(term1: Expression, term2: Expression, term3: Expression, subs: Substituted): Sequence<Substituted> = sequence {
val t1 = applySubstitution(term1, subs)
val t2 = applySubstitution(term2, subs)
val t3 = applySubstitution(term3, subs)
// Base case
if (equivalent(t2, Integer(0))) {
yieldAll(Is(t3, Integer(0)).prove(subs))
}
// Recursive case
try {
val decremented = Variable("Decremented")
succ(decremented, t2, subs).forEach { decrementMap ->
val multiplied = Variable("Multiplied")
mul(t1, decremented, multiplied, subs + decrementMap).forEach { multipliedMap ->
yieldAll(plus(t1, multiplied, t3, subs + decrementMap + multipliedMap))
}
}
} catch(_: Exception) {
}
}
// TODO divmod
// TODO nth_integer_root_and_remainder