Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Oct 9, 2024
1 parent 39d3f84 commit de3de54
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ai.hypergraph.kaliningraph.repair

import ai.hypergraph.kaliningraph.parsing.freeze
import ai.hypergraph.kaliningraph.parsing.noEpsilonOrNonterminalStubs
import ai.hypergraph.kaliningraph.parsing.noNonterminalStubs
import ai.hypergraph.kaliningraph.parsing.parseCFG

val s2pCFGStr = """
Expand Down Expand Up @@ -197,4 +198,5 @@ Yield_Expr -> Yield_Keyword | Yield_Keyword Yield_Arg
Yield_Arg -> From_Keyword Test | Testlist_Endcomma
"""

val vanillaS2PCFG by lazy { s2pCFGStr.parseCFG().noEpsilonOrNonterminalStubs.freeze() }
val vanillaS2PCFG by lazy { s2pCFGStr.parseCFG().noEpsilonOrNonterminalStubs.freeze() }
val vanillaS2PCFGWE by lazy { s2pCFGStr.parseCFG().noNonterminalStubs.freeze() }
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Grammars.shortS2PParikhMap
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.repair.vanillaS2PCFG
import ai.hypergraph.kaliningraph.repair.vanillaS2PCFGWE
import kotlin.test.*
import kotlin.time.*

Expand Down Expand Up @@ -366,14 +367,13 @@ class BarHillelTest {
*/
@Test
fun levenshteinBlanketTest() {
val gram = vanillaS2PCFG.noEpsilonOrNonterminalStubs
val origStr= "NAME = NAME . NAME ( [ NUMBER , NUMBER , NUMBER ] NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val levDist = 2
val levBall = makeLevFSA(toRepair, levDist)
val clock = TimeSource.Monotonic.markNow()

val s2pg = vanillaS2PCFG
val s2pg = vanillaS2PCFGWE
s2pg.fasterRepairSeq(toRepair, 1, 2).distinct()
.mapIndexedNotNull { i, it ->
val levDistance = levenshtein(origStr, it)
Expand All @@ -384,6 +384,7 @@ class BarHillelTest {
it
} else null
}.takeWhile { clock.elapsedNow().inWholeSeconds < 30 }.toList()
.also { assertTrue(it.isNotEmpty()) }
.also { println("Found ${it.size} minimal solutions using " +
"Probabilistic repair in ${clock.elapsedNow()}") }

Expand Down Expand Up @@ -439,20 +440,20 @@ class BarHillelTest {
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.testToyArith"
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.testEnumLBEquality"
*/
@Test
fun testToyArith() {
fun testEnumLBEquality() {
val prompt = ") ( (".tokenizeByWhitespace()
val overwrittenRepairs =
Grammars.toyArith.barHillelRepair(prompt, 3).toSet()
.also { println("Found ${it.size} overwritten repairs.") }
Grammars.dyck.barHillelRepair(prompt, 3).toSet()
.also { println("Found ${it.size} BH repairs.") }

val allTriples = Grammars.toyArith.solveSeq(List(3) { "_" })
.distinct().toSet().also { println("Found ${it.size} total triples.") }
val allTriples = Grammars.dyck.solveSeq(List(6) { "_" })
.distinct().toSet()
.filter { levenshtein(prompt, it.tokenizeByWhitespace()) <= 3 }.toSet()
.also { println("Found ${it.size} total triples.") }

val allTriplesMinusOverwritten = overwrittenRepairs - allTriples
allTriplesMinusOverwritten.forEach { println(it) }
println("Found ${allTriplesMinusOverwritten.size} non-overwritten triples.")
assertEquals(overwrittenRepairs, allTriples)
}
}

0 comments on commit de3de54

Please sign in to comment.