Skip to content

Commit

Permalink
relax memory bottleneck
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Nov 13, 2024
1 parent 8560f01 commit 3e87340
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,16 @@ fun CFG.barHillelRepair(prompt: List<Σᐩ>, distance: Int) =
// https://browse.arxiv.org/pdf/2209.06809.pdf#page=5
fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG {
var clock = TimeSource.Monotonic.markNow()
val nts = mutableSetOf("START")
fun Σᐩ.isSyntheticNT() =
first() == '[' && last() == ']' && count { it == '~' } == 2
fun Iterable<Production>.filterRHSInNTS() =
asSequence().filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } }
val nts = mutableSetOf(listOf("START"))
fun List<Σᐩ>.isSyntheticNT() = size > 1

val initFinal =
(fsa.init * fsa.final).map { (q, r) -> "START" to listOf("[$q~START~$r]") }
.filterRHSInNTS()
val initFinal = (fsa.init * fsa.final).map { (q, r) -> listOf("START") to listOf(listOf(q,"START",r)) }

// For every production A → σ in P, for every (p, σ, q) ∈ Q × Σ × Q
// such that δ(p, σ) = q we have the production [p, A, q] → σ in P′.
val unitProds = unitProdRules(fsa).onEach { (a, _) -> nts.add(a) }
val unitProds = unitProdRules2(fsa).map { (a, b) -> a.also { nts.add(it) } to b }

fun List<Σᐩ>.toNT() = if (size == 1) first() else "[" + joinToString("~") + "]"

// For each production A → BC in P, for every p, q, r ∈ Q,
// we have the production [p,A,r] → [p,B,q] [q,C,r] in P′.
Expand Down Expand Up @@ -73,13 +70,18 @@ fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG {
// .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap) }
.map { (a, b, c) ->
val (p, q, r) = states[a] to states[b] to states[c]
"[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]")
// "[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]")
listOf(p, allsym[A], r).also { nts.add(it) } to listOf(listOf(p, allsym[B], q), listOf(q, allsym[C], r))
}.toList()
}.flatten().filterRHSInNTS()
}.flatten()

println("Constructing ∩-grammar took: ${clock.elapsedNow()}")
clock = TimeSource.Monotonic.markNow()
return (initFinal + binaryProds + unitProds).toSet().postProcess()
return (initFinal + binaryProds + unitProds)
.filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } }
.map { (l, r) -> l.toNT() to r.map { it.toNT() } }
.toSet()
.postProcess()
// .expandNonterminalStubs(origCFG = this@intersectLevFSAP)
.also { println("Bar-Hillel construction took ${clock.elapsedNow()}") }
}
Expand All @@ -101,6 +103,12 @@ fun CFG.unitProdRules(fsa: FSA): List<Pair<String, List<Σᐩ>>> =
// else null
// }

/**
 * Builds the terminal-producing rules of the intersection grammar in
 * structured (list-based) form rather than as pre-rendered strings.
 *
 * Every unit production of this grammar is combined with every triple from
 * `fsa.nominalize().flattenedTriples`. A combination is kept only when the
 * triple's middle component, invoked as a predicate, accepts the production's
 * right-hand symbol σ.
 *
 * Each surviving combination yields a rule whose left-hand side is the
 * three-element list `(π1, A, π3)` and whose right-hand side is the single
 * one-element symbol `[σ]`. The left-hand list is the unrendered counterpart
 * of the string nonterminal "[π1~A~π3]".
 *
 * @param fsa the automaton whose nominalized triples are matched against the
 *   unit productions
 * @return one `lhs to rhs` pair per accepted (production, triple) combination
 */
fun CFG.unitProdRules2(fsa: FSA): List<Pair<List<String>, List<List<Σᐩ>>>> {
  // All (A, σ, triple) combinations of unit productions and FSA triples.
  val candidates = unitProductions * fsa.nominalize().flattenedTriples

  // Keep only combinations whose triple predicate accepts the symbol σ.
  val accepted = candidates.filter { (_, terminal: Σᐩ, edge) -> (edge.π2)(terminal) }

  return accepted.map { (lhs, terminal, edge) ->
    listOf(edge.π1, lhs, edge.π3) to listOf(listOf(terminal))
  }
}

fun CFG.expandNonterminalStubs(origCFG: CFG) = flatMap {
// println("FM: $it / ${it.RHS.first()} / ${it.RHS.first().isNonterminalStub()}")
if (it.RHS.size != 1 || !it.RHS.first().isNonterminalStub()) listOf(it)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,19 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF
if (parikhMap.size < fsa.width + fsa.height) throw Exception("WARNING: Parikh map size exceeded")
var clock = TimeSource.Monotonic.markNow()

val nts = ConcurrentHashMap.newKeySet<Σᐩ>().apply { add("START") }
val nts = ConcurrentHashMap.newKeySet<List<Σᐩ>>().apply { add(listOf("START")) }

val initFinal =
(fsa.init * fsa.final).map { (q, r) -> "START" to listOf("[$q~START~$r]") }
(fsa.init * fsa.final).map { (q, r) -> listOf("START") to listOf(listOf(q, "START", r)) }

val transits =
fsa.Q.map { (q, a, r) -> "[$q~$a~$r]".also { nts.add(it) } to listOf(a) }
fsa.Q.map { (q, a, r) -> listOf(q, a, r).also { nts.add(it) } to listOf(listOf(a)) }

// For every production A → σ in P, for every (p, σ, q) ∈ Q × Σ × Q
// such that δ(p, σ) = q we have the production [p, A, q] → σ in P′.
val unitProds = unitProdRules(fsa)
.toSet().onEach { (a, _) -> nts.add(a) }
val unitProds = unitProdRules2(fsa)
.map { (a, b) -> a.also { nts.add(it) } to b }
.toSet()

val ccClock = TimeSource.Monotonic.markNow()
val compat: Array<Array<Array<Boolean>>> = computeNTCompat(this, fsa.levString)
Expand Down Expand Up @@ -254,16 +255,17 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF
.map { (a, b, c) ->
if (MAX_PRODS < counter.incrementAndGet()) throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)")
val (p, q, r) = states[a] to states[b] to states[c]
"[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]")
// "[$p~${allsym[A]}~$r]".also { nts.add(listOf(p, allsym[A], r)) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]")
listOf(p, allsym[A], r).also { nts.add(it) } to listOf(listOf(p, allsym[B], q), listOf(q, allsym[C], r))
}
}.toList()

val elimCounter = (validTriples.size * prods.size) - binaryProds.size
println("Levenshtein-Parikh constraints eliminated $elimCounter productions in ${lpClock.elapsedNow()}")

// !isSyntheticNT() === is START or a terminal
fun Σᐩ.isSyntheticNT() =
first() == '[' && length > 1 // && last() == ']' && count { it == '~' } == 2
fun List<Σᐩ>.isSyntheticNT() = size > 1
fun List<Σᐩ>.toNT() = if (size == 1) first() else "[" + joinToString("~") + "]"

val totalProds = binaryProds.size + transits.size + unitProds.size + initFinal.size
println("Constructed ∩-grammar with $totalProds productions in ${clock.elapsedNow()}")
Expand All @@ -273,6 +275,7 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF
// A production, e.g., * -> * [G], can be removed if the synthetic nonterminal [G] does not exist, i.e.,
// every instance of [G] -> * * was incompatible with the FSA, so the nonterminal [G] is "unproductive".
.filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } }
.map { (l, r) -> l.toNT() to r.map { it.toNT() } }
.collect(Collectors.toSet())
.also { println("Eliminated ${totalProds - it.size} extra productions before normalization") }
.jvmPostProcess(clock)
Expand Down

0 comments on commit 3e87340

Please sign in to comment.