diff --git a/src/interpreter/Preprocessor.kt b/src/interpreter/Preprocessor.kt index b6e7656..d10b605 100644 --- a/src/interpreter/Preprocessor.kt +++ b/src/interpreter/Preprocessor.kt @@ -1,4 +1,93 @@ package interpreter -class Preprocessor { +import io.Logger +import prolog.ast.logic.Clause +import prolog.ast.logic.Fact +import prolog.ast.logic.LogicOperand +import prolog.ast.logic.Rule +import prolog.ast.terms.Atom +import prolog.ast.terms.Body +import prolog.ast.terms.Goal +import prolog.ast.terms.Head +import prolog.ast.terms.Structure +import prolog.ast.terms.Term +import prolog.builtins.Conjunction +import prolog.builtins.Cut +import prolog.builtins.Disjunction +import prolog.builtins.Fail +import prolog.builtins.False +import prolog.builtins.Not +import prolog.builtins.Query +import prolog.builtins.True + +/** + * Preprocessor for Prolog + * + * This class preprocesses Prolog code and applies various transformations such as recognizing builtins. + */ +open class Preprocessor { + /** + * Preprocesses the input Prolog code. + * + * @param input The already parsed Prolog code as a list of clauses. + * @return The preprocessed Prolog code as a list of clauses. + */ + fun preprocess(input: List): List { + return input.map { preprocess(it) } + } + + fun preprocess(input: Query): Query { + return Query(preprocess(input.query) as Goal) + } + + private fun preprocess(clause: Clause): Clause { + return when (clause) { + is Fact -> { + Fact(preprocess(clause.head) as Head) + } + is Rule -> { + Rule( + preprocess(clause.head) as Head, + preprocess(clause.body as Term) as Body + ) + } + else -> clause + } + } + + protected open fun preprocess(term: Term): Term { + val prepped = when (term) { + Atom("true") -> True + Structure(Atom("true"), emptyList()) -> True + Atom("false") -> False + Structure(Atom("false"), emptyList()) -> False + Atom("fail") -> Fail + Structure(Atom("fail"), emptyList()) -> Fail + Atom("!") -> Cut() + Structure(Atom("!"), emptyList()) -> Cut() + else -> { + when { + term is Structure && term.functor == ",/2" -> { + val args = term.arguments.map { preprocess(it) } + Conjunction(args[0] as LogicOperand, args[1] as LogicOperand) + } + term is Structure && term.functor == ";/2" -> { + val args = term.arguments.map { preprocess(it) } + Disjunction(args[0] as LogicOperand, args[1] as LogicOperand) + } + term is Structure && term.functor == "\\+/1" -> { + val args = term.arguments.map { preprocess(it) } + Not(args[0] as Goal) + } + else -> term + } + } + } + + if (prepped != term || prepped::class != term::class) { + Logger.debug("Preprocessed term: $term -> $prepped (${prepped::class})") + } + + return prepped + } } \ No newline at end of file diff --git a/src/prolog/builtins/arithmeticOperators.kt b/src/prolog/builtins/arithmeticOperators.kt index 2463313..7931e7d 100644 --- a/src/prolog/builtins/arithmeticOperators.kt +++ b/src/prolog/builtins/arithmeticOperators.kt @@ -152,8 +152,8 @@ class Between(private val expr1: Expression, private val expr2: Expression, priv require(e1.to is Integer && e2.to is Integer) { "Arguments must be integers" } - val v1 = e1.to as Integer - val v2 = e2.to as Integer + val v1 = e1.to + val v2 = e2.to return if (variable(e3.to, subs)) { between(v1, v2, e3.to as Variable).map { answer -> diff --git a/src/repl/Repl.kt b/src/repl/Repl.kt index 699ab86..52cce29 100644 --- a/src/repl/Repl.kt +++ b/src/repl/Repl.kt @@ -1,5 +1,6 @@ package repl +import interpreter.Preprocessor import io.Logger import io.Terminal import parser.ReplParser @@ -9,6 +10,7 @@ import prolog.Answers class Repl { private val io = Terminal() private val parser = ReplParser() + private val preprocessor = Preprocessor() fun start() { io.say("Prolog REPL. Type '^D' to quit.\n") @@ -23,7 +25,8 @@ class Repl { fun query(): Answers { val queryString = io.prompt("?-", { "" }) - val query = parser.parse(queryString) + val simpleQuery = parser.parse(queryString) + val query = preprocessor.preprocess(simpleQuery) return query.satisfy(emptyMap()) } @@ -46,7 +49,10 @@ class Repl { } when (command) { - ";" -> previous = iterator.next() + ";" -> { + previous = iterator.next() + io.say(prettyPrint(previous)) + } "a" -> return "." -> return "h" -> { @@ -55,8 +61,6 @@ class Repl { } } } - - io.say(prettyPrint(previous)) } io.say("\n") diff --git a/tests/interpreter/PreprocessorTests.kt b/tests/interpreter/PreprocessorTests.kt new file mode 100644 index 0000000..a97755a --- /dev/null +++ b/tests/interpreter/PreprocessorTests.kt @@ -0,0 +1,156 @@ +package interpreter + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import prolog.ast.terms.Atom +import prolog.ast.terms.CompoundTerm +import prolog.ast.terms.Term +import prolog.builtins.Conjunction +import prolog.builtins.Disjunction +import prolog.builtins.Cut +import prolog.builtins.Fail +import prolog.builtins.True +import prolog.builtins.Not + +class PreprocessorTests { + class OpenPreprocessor : Preprocessor() { + public override fun preprocess(input: Term): Term { + return super.preprocess(input) + } + } + + @Nested + class `Arithmetic operators` { + @Test + fun `evaluates to different`() { + assertEquals(1, 2) + } + } + + @Nested + class `Control operators` { + private var preprocessor = OpenPreprocessor() + + @Test + fun fail() { + val tests = mapOf( + Atom("fail") to Fail, + CompoundTerm(Atom("fail"), emptyList()) to Fail, + Atom("Fail") to Atom("Fail"), + CompoundTerm(Atom("Fail"), emptyList()) to CompoundTerm(Atom("Fail"), emptyList()), + CompoundTerm(Atom("fail"), listOf(Atom("a"))) to CompoundTerm(Atom("fail"), listOf(Atom("a"))), + CompoundTerm(Atom("fail"), listOf(Atom("fail"))) to CompoundTerm(Atom("fail"), listOf(Fail)) + ) + + for ((input, expected) in tests) { + val result = preprocessor.preprocess(input) + assertEquals(expected, result, "Expected preprocessed") + assertEquals(expected::class, result::class, "Expected same class") + } + } + + @Test + fun `true`() { + val tests = mapOf( + Atom("true") to True, + CompoundTerm(Atom("true"), emptyList()) to True, + Atom("True") to Atom("True"), + CompoundTerm(Atom("True"), emptyList()) to CompoundTerm(Atom("True"), emptyList()), + CompoundTerm(Atom("true"), listOf(Atom("a"))) to CompoundTerm(Atom("true"), listOf(Atom("a"))), + CompoundTerm(Atom("true"), listOf(Atom("true"))) to CompoundTerm(Atom("true"), listOf(True)) + ) + + for ((input, expected) in tests) { + val result = preprocessor.preprocess(input) + assertEquals(expected, result, "Expected preprocessed") + assertEquals(expected::class, result::class, "Expected same class") + } + } + + @Test + fun cut() { + val tests = mapOf( + Atom("!") to Cut(), + CompoundTerm(Atom("!"), emptyList()) to Cut(), + CompoundTerm(Atom("!"), listOf(Atom("a"))) to CompoundTerm(Atom("!"), listOf(Atom("a"))), + CompoundTerm(Atom("!"), listOf(Atom("!"))) to CompoundTerm(Atom("!"), listOf(Cut())) + ) + + for ((input, expected) in tests) { + val result = preprocessor.preprocess(input) + assertEquals(expected, result, "Expected preprocessed") + assertEquals(expected::class, result::class, "Expected same class") + } + } + + @Test + fun conjunction() { + val tests = mapOf( + CompoundTerm(Atom(","), listOf(Atom("a"), Atom("b"))) to Conjunction(Atom("a"), Atom("b")), + CompoundTerm(Atom(","), listOf(Atom("a"), Atom("b"), Atom("c"))) to CompoundTerm( + Atom(","), + listOf(Atom("a"), Atom("b"), Atom("c")) + ), + // Nested conjunctions + CompoundTerm( + Atom(","), + listOf(Atom("a"), CompoundTerm(Atom(","), listOf(Atom("b"), Atom("c")))) + ) to Conjunction(Atom("a"), Conjunction(Atom("b"), Atom("c"))), + ) + + for ((input, expected) in tests) { + val result = preprocessor.preprocess(input) + assertEquals(expected, result, "Expected preprocessed") + assertEquals(expected::class, result::class, "Expected same class") + } + } + + @Test + fun disjunction() { + val tests = mapOf( + CompoundTerm(Atom(";"), listOf(Atom("a"), Atom("b"))) to Disjunction(Atom("a"), Atom("b")), + CompoundTerm(Atom(";"), listOf(Atom("a"), Atom("b"), Atom("c"))) to CompoundTerm( + Atom(";"), + listOf(Atom("a"), Atom("b"), Atom("c")) + ), + // Nested disjunctions + CompoundTerm( + Atom(";"), + listOf(Atom("a"), CompoundTerm(Atom(";"), listOf(Atom("b"), Atom("c")))) + ) to Disjunction(Atom("a"), Disjunction(Atom("b"), Atom("c"))), + ) + + for ((input, expected) in tests) { + val result = preprocessor.preprocess(input) + assertEquals(expected, result, "Expected preprocessed") + assertEquals(expected::class, result::class, "Expected same class") + } + } + + @Test + fun not() { + val tests = mapOf( + CompoundTerm(Atom("\\+"), listOf(Atom("a"))) to Not(Atom("a")), + CompoundTerm(Atom("\\+"), listOf(Atom("a"), Atom("b"))) to CompoundTerm( + Atom("\\+"), + listOf(Atom("a"), Atom("b")) + ), + // Nested not + CompoundTerm( + Atom("foo"), + listOf( + Atom("bar"), + CompoundTerm(Atom("\\+"), listOf(CompoundTerm(Atom("\\+"), listOf(Atom("baz"))))) + ) + ) to CompoundTerm(Atom("foo"), listOf(Atom("bar"), Not(Not(Atom("baz"))))), + ) + + for ((input, expected) in tests) { + val result = preprocessor.preprocess(input) + assertEquals(expected, result, "Expected preprocessed") + assertEquals(expected::class, result::class, "Expected same class") + } + } + } +} diff --git a/tests/parser/OperatorParserTests.kt b/tests/parser/OperatorParserTests.kt new file mode 100644 index 0000000..e15e89d --- /dev/null +++ b/tests/parser/OperatorParserTests.kt @@ -0,0 +1,28 @@ +package parser + +import com.github.h0tk3y.betterParse.grammar.Grammar +import com.github.h0tk3y.betterParse.grammar.parseToEnd +import com.github.h0tk3y.betterParse.parser.Parser +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import parser.grammars.TermsGrammar +import prolog.ast.terms.Atom +import prolog.ast.terms.Operator +import prolog.ast.terms.Structure + +class OperatorParserTests { + class OperatorParser: TermsGrammar() { + override val rootParser: Parser by operator + } + + private var parser = OperatorParser() as Grammar + + @Test + fun `parse conjunction`() { + val input = "a, b" + + val result = parser.parseToEnd(input) + + assertEquals(Structure(Atom(","), listOf(Atom("a"), Atom("b"))), result, "Expected atom 'a, b'") + } +} \ No newline at end of file