Skip to content

Commit

Permalink
optimize serializer and bump limit to 100 tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Oct 30, 2024
1 parent 492027d commit 0071449
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,35 +197,10 @@ val CFG.parikhMap: ParikhMap by cache {
val parikhMap = if (hashCode() in langCache)
ParikhMap.deserialize(this, langCache[hashCode()]!!)
else ParikhMap(this, MAX_TOKENS + 5)
println("Computed Parikh map in ${clock.elapsedNow()}")
println("Obtained Parikh map in ${clock.elapsedNow()}")
parikhMap
}

/**
 * Tracks the number of tokens a given nonterminal can represent,
 * e.g., a NT with a bound of 1..3 can parse { s: Σ^[1, 3] }.
 *
 * Every nonterminal of the epsilon-free grammar starts at the sentinel range -1..-1.
 * The diagonals of the fixpoint matrix are then visited in order: a nonterminal found
 * among the keys of the first element of diagonal `idx` is recorded with length `idx + 1`.
 * The lower bound is fixed the first time a nonterminal is seen and the upper bound is
 * overwritten on every later sighting. Nonterminals that are never seen keep -1..-1.
 */
val CFG.lengthBounds: Map<Σᐩ, IntRange> by cache {
  val clock = TimeSource.Monotonic.markNow()
  val epsFree = noEpsilonOrNonterminalStubs.freeze()
  val tpl = List(MAX_TOKENS + 5) { "_" }
  val bounds = epsFree.nonterminals.associateWith { -1..-1 }.toMutableMap()

  // forEachIndexed rather than mapIndexed: the traversal is run only for its side effects on `bounds`.
  epsFree.initPForestMat(tpl).seekFixpoint().diagonals.forEachIndexed { idx, sets ->
    val length = idx + 1
    for (nt in sets.first().keys) {
      // Keys that are not nonterminals of the epsilon-free grammar are ignored.
      val seen = bounds[nt] ?: continue
      // A negative lower bound is the "not yet seen" sentinel.
      bounds[nt] = (if (seen.first < 0) length else seen.first)..length
    }
  }

  println("Computed NT length bounds in ${clock.elapsedNow()}")
  bounds
}

// Length bounds listed in the iteration order of `nonterminals`; a nonterminal
// with no entry in `lengthBounds` falls back to 0..0.
val CFG.lengthBoundsCache by cache {
  val lb = lengthBounds
  nonterminals.map { nt -> lb[nt] ?: 0..0 }
}

fun Π3A<STC>.isValidStateTriple(): Boolean {
fun Pair<Int, Int>.dominates(other: Pair<Int, Int>) =
first <= other.first && second <= other.second
Expand Down
40 changes: 21 additions & 19 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Parikh.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.repair.*
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.types.cache
import kotlin.jvm.JvmName
import kotlin.math.*

// Number of each terminal (necessary, possible)
typealias ParikhBounds = Map<Σᐩ, IntRange>
Expand Down Expand Up @@ -56,34 +53,25 @@ class ParikhMap(val cfg: CFG, val size: Int, reconstruct: Boolean = true) {
private val parikhMap: MutableMap<Int, ParikhBoundsMap> = mutableMapOf()
val parikhRangeMap: MutableMap<IntRange, ParikhBoundsMap> = mutableMapOf() // Parameterized Parikh map
val ntIdx = cfg.nonterminals.toList()
val ntLengthBounds: MutableList<IntRange> = mutableListOf()

companion object {
/**
 * Renders [pm] as text in two sections separated by a "====" line padded with blank lines:
 * first the output of [serializePM] for the Parikh map, then one line per `lengthBounds`
 * entry holding the key followed by the space-joined elements of its value.
 */
fun serialize(pm: ParikhMap): String {
  val parikhSection = serializePM(pm.parikhMap)
  val boundsSection = pm.lengthBounds.entries.joinToString("\n") { entry ->
    "${entry.key} ${entry.value.joinToString(" ")}"
  }
  return parikhSection + "\n\n====\n\n" + boundsSection
}

fun serializePM(pm: Map<Int, ParikhBoundsMap>) =
pm.entries.joinToString("\n") { (k0, v0) ->
v0.entries.joinToString("\n") { (k1, v1) ->
v1.entries.joinToString("\n") { (k2, v2) ->
"$k0 $k1 $k2 ${v2.first} ${v2.last}"
}
pm.entries.joinToString("\n") { (k0: Int, v0: ParikhBoundsMap) ->
v0.entries.joinToString("\n") { (k1: String, v1: Map<Σᐩ, IntRange>) ->
"$k0 $k1 : " + v1.entries.joinToString(" ") { (k2, v2) -> "$k2 ${v2.first} ${v2.last}" }
}
}

fun deserializePM(str: String): Map<Int, ParikhBoundsMap> =
str.lines().filter { it.isNotBlank() }
.map { it.split(" ") }.groupBy { it[0].toInt() }.mapValues { (_, v0) ->
v0.groupBy { it[1] }.mapValues { (_, v1) ->
v1.map { it[2] to it[3].toInt()..it[4].toInt() }.toMap()
str.lines().map { it.split(" ") }.groupBy { it.first().toInt() }
.mapValues { (_, v) ->
v.map { it[1] to it.drop(3).chunked(3).map { it[0] to (it[1].toInt()..it[2].toInt()) }.toMap() }.toMap()
}
}.mapValues { (_, v0) ->
v0.mapValues { (_, v1) ->
v1.mapValues { (_, v2) ->
v2
}
}
}

fun deserialize(cfg: CFG, str: String): ParikhMap {
val pm = deserializePM(str.substringBefore("\n\n====\n\n"))
Expand All @@ -94,6 +82,7 @@ class ParikhMap(val cfg: CFG, val size: Int, reconstruct: Boolean = true) {
parikhMap.putAll(pm)
lengthBounds.putAll(lb)
populatePRMFromPM()
populateLengthBounds()
}
}

Expand All @@ -117,6 +106,18 @@ class ParikhMap(val cfg: CFG, val size: Int, reconstruct: Boolean = true) {
}
}

/**
 * Fills [ntLengthBounds] with one range per nonterminal, in the iteration order of `cfg.nonterminals`.
 *
 * A nonterminal's range runs from the least to the greatest key of `lengthBounds` whose value
 * contains that nonterminal. A nonterminal that appears under no key is given 0..0.
 *
 * Previous contents of [ntLengthBounds] are discarded before the new ranges are stored, so calling
 * this more than once cannot append a second copy and shift the positions callers index into.
 */
fun populateLengthBounds() {
  // Compute everything first so the list is only replaced once the new ranges are complete.
  val computed = cfg.nonterminals.map { nt ->
    val lengths = lengthBounds.entries.filter { nt in it.value }.map { it.key }
    // An empty `lengths` yields 0..0 via the elvis defaults.
    (lengths.minOrNull() ?: 0)..(lengths.maxOrNull() ?: 0)
  }

  ntLengthBounds.clear()
  ntLengthBounds.addAll(computed)
}

init {
if (reconstruct) {
val template = List(size) { "_" }
Expand All @@ -128,6 +129,7 @@ class ParikhMap(val cfg: CFG, val size: Int, reconstruct: Boolean = true) {
}

populatePRMFromPM()
populateLengthBounds()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,7 @@ val MAX_PRODS = 150_000_000

// We pass pm and lbc because the cache is often flushed, forcing them to be reloaded,
// and we know they will usually be the same for all calls to this function.
fun CFG.jvmIntersectLevFSAP(
fsa: FSA,
parikhMap: ParikhMap = this.parikhMap,
lbc: List<IntRange> = this.lengthBoundsCache
): CFG {
fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG {
// if (fsa.Q.size < 650) throw Exception("FSA size was out of bounds")
if (parikhMap.size < fsa.width + fsa.height) throw Exception("WARNING: Parikh map size exceeded")
var clock = TimeSource.Monotonic.markNow()
Expand Down Expand Up @@ -194,7 +190,7 @@ fun CFG.jvmIntersectLevFSAP(
// Checks whether the length bounds for the nonterminal (i.e., the range of the number of terminals it can
// parse) are compatible with the range of path lengths across all paths connecting two states in an FSA.
// This is a coarse approximation, but is cheaper to compute, so it filters out most invalid triples.
lbc[it.π3].overlaps(
parikhMap.ntLengthBounds[it.π3].overlaps(
fsa.SPLP(it.π1, it.π2)
) &&
// Checks the Parikh map for compatibility between the CFG nonterminals and state pairs in the FSA.
Expand Down
Binary file modified src/jvmMain/resources/1566012639.cache.zip
Binary file not shown.

0 comments on commit 0071449

Please sign in to comment.