Skip to content

Commit

Permalink
chase down bug with MAX_TOKENS
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Oct 24, 2024
1 parent 68a9f85 commit 18dc689
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ open class FSA(open val Q: TSA, open val init: Set<Σᐩ>, open val final: Set<
}

// Lazily computed: for every state, a Triple of (stateMap[state], i, j), where (i, j) = state.coords().
val stateCoords: Sequence<STC> by lazy { states.map { it.coords().let { (i, j) -> Triple(stateMap[it]!!, i, j) } }.asSequence() }
// Both stay 0 unless assigned after construction; makeLevFSA sets height = dist and width = str.size,
// and jvmIntersectLevFSAP compares parikhMap.size against width + height.
var height = 0
var width = 0

// Lazily computed list of state-coordinate triples that pass isValidStateTriple().
// NOTE(review): `*` is presumably the Cartesian-product operator from kaliningraph.types — confirm.
val validTriples by lazy { stateCoords.let { it * it * it }.filter { it.isValidStateTriple() }.toList() }
// Lazily computed set of state-coordinate pairs that pass isValidStatePair().
val validPairs by lazy { stateCoords.let { it * it }.filter { it.isValidStatePair() }.toSet() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ fun makeLevFSA(
}

FSA(Q, initialStates, finalStates)
.also { it.height = dist; it.width = str.size }
// .nominalize()
.also { println("Levenshtein-${str.size}x$dist automaton has ${Q.size} arcs!") }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ParikhMap(val cfg: CFG, val size: Int, reconstruct: Boolean = true) {
}
}

fun genRanges(delta: Int = 2 * MAX_RADIUS + 1, n: Int = MAX_TOKENS) =
fun genRanges(delta: Int = 2 * MAX_RADIUS + 1, n: Int = MAX_TOKENS + MAX_RADIUS) =
(1..delta).map { margin ->
val range = (0..n).toList()
range.windowed(margin, 1).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,9 @@ fun BAutomaton.decodeDFA(
println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories")

return deduped
}
}

/**
 * Decodes the strings returned by `getFiniteStrings(take)` into space-separated token strings.
 *
 * Every character of every returned string is looked up in [dec]; the resulting tokens are
 * joined with a single space. A character missing from [dec] triggers the `!!` failure.
 *
 * @param dec maps unicode characters back to strings, because BAutomata uses Unicode
 * @param take forwarded unchanged to `getFiniteStrings`
 */
fun BAutomaton.decodeDFA(
  dec: Map<Char, Σᐩ>,
  take: Int = 1000,
) = getFiniteStrings(take).map { encoded ->
  encoded.map { symbol -> dec[symbol]!! }.joinToString(" ")
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.repair.minimizeFix
import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.kaliningraph.types.times
import java.util.concurrent.*
import java.util.stream.*
import kotlin.streams.*
import kotlin.time.Duration.Companion.minutes
Expand Down Expand Up @@ -165,6 +164,7 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA,
lbc: List<IntRange> = this.lengthBoundsCache
): CFG {
// if (fsa.Q.size < 650) throw Exception("FSA size was out of bounds")
if (parikhMap.size < fsa.width + fsa.height) throw Exception("WARNING: Parikh map size exceeded")
var clock = TimeSource.Monotonic.markNow()

val nts = ConcurrentHashMap.newKeySet<Σᐩ>().apply { add("START") }
Expand Down Expand Up @@ -200,8 +200,9 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA,
// This is a finer grained filter, but more expensive to compute, so we use the coarse filter first
fsa.obeys(it.π1, it.π2, it.π3, parikhMap)
}.toList().also {
val fraction = it.size.toDouble() / (fsa.states.size * nonterminals.size * fsa.states.size)
println("Fraction of valid triples: $fraction")
val candidates = (fsa.states.size * nonterminals.size * fsa.states.size)
val fraction = it.size.toDouble() / candidates
println("Fraction of valid triples: ${it.size}/$candidates$fraction")
}.forEach { ct2[it.π11][it.π3][it.π21] = true }
println("Precomputed LP constraints in ${ctClock.elapsedNow()}")

Expand Down
22 changes: 22 additions & 0 deletions src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,28 @@ class WFSATest {
}.also { println("Decoding ${it.value.size} repairs took ${it.duration}") }
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.automata.WFSATest.testRepairMembership"
*/
/**
 * Builds a radius-1 Levenshtein automaton around a broken snippet, intersects it with
 * vanillaS2PCFG, and checks that:
 *  - the decoded DFA yields at least one repair,
 *  - every repair is recognized by the automaton and belongs to the grammar's language,
 *  - the decoded set equals the set produced by sampleStrWithoutReplacement(),
 *  - the ground-truth repair belongs to the intersection grammar's language.
 */
@Test
fun testRepairMembership() {
  val broken = "if STRING in NAME : return [ NEWLINE"
  val expectedFix = "if STRING in NAME : return NUMBER NEWLINE"
  println(expectedFix in vanillaS2PCFG.language)

  val editDistance = 1
  val levAutomaton = makeLevFSA(broken, editDistance)
  val intersection = vanillaS2PCFG.run { jvmIntersectLevFSAP(levAutomaton, parikhMap) }
  val tree = intersection.toPTree()

  // Materialize the decoded repairs once so each check below sees the same set.
  val decoded = tree.toDFA(true)!!.decodeDFA(tree.termDict).toSet()
  assertTrue(decoded.isNotEmpty())
  for (repair in decoded) {
    assertTrue(levAutomaton.recognizes(repair) && repair in vanillaS2PCFG.language)
    println(levenshteinAlign(broken, repair).paintANSIColors())
  }
  assertEquals(tree.sampleStrWithoutReplacement().toSet(), decoded)

  assertTrue(expectedFix in intersection.language)
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.automata.WFSATest.testBijection"
*/
Expand Down

0 comments on commit 18dc689

Please sign in to comment.