diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt index 44d86ef7..217bd929 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt @@ -29,19 +29,16 @@ fun CFG.barHillelRepair(prompt: List<Σᐩ>, distance: Int) = // https://browse.arxiv.org/pdf/2209.06809.pdf#page=5 fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG { var clock = TimeSource.Monotonic.markNow() - val nts = mutableSetOf("START") - fun Σᐩ.isSyntheticNT() = - first() == '[' && last() == ']' && count { it == '~' } == 2 - fun Iterable.filterRHSInNTS() = - asSequence().filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } } + val nts = mutableSetOf(listOf("START")) + fun List<Σᐩ>.isSyntheticNT() = size > 1 - val initFinal = - (fsa.init * fsa.final).map { (q, r) -> "START" to listOf("[$q~START~$r]") } - .filterRHSInNTS() + val initFinal = (fsa.init * fsa.final).map { (q, r) -> listOf("START") to listOf(listOf(q,"START",r)) } // For every production A → σ in P, for every (p, σ, q) ∈ Q × Σ × Q // such that δ(p, σ) = q we have the production [p, A, q] → σ in P′. - val unitProds = unitProdRules(fsa).onEach { (a, _) -> nts.add(a) } + val unitProds = unitProdRules2(fsa).map { (a, b) -> a.also { nts.add(it) } to b } + + fun List<Σᐩ>.toNT() = if (size == 1) first() else "[" + joinToString("~") + "]" // For each production A → BC in P, for every p, q, r ∈ Q, // we have the production [p,A,r] → [p,B,q] [q,C,r] in P′. @@ -73,13 +70,18 @@ fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG { // .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap) } .map { (a, b, c) -> val (p, q, r) = states[a] to states[b] to states[c] - "[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]") +// "[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]") + listOf(p, allsym[A], r).also { nts.add(it) } to listOf(listOf(p, allsym[B], q), listOf(q, allsym[C], r)) }.toList() - }.flatten().filterRHSInNTS() + }.flatten() println("Constructing ∩-grammar took: ${clock.elapsedNow()}") clock = TimeSource.Monotonic.markNow() - return (initFinal + binaryProds + unitProds).toSet().postProcess() + return (initFinal + binaryProds + unitProds) + .filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } } + .map { (l, r) -> l.toNT() to r.map { it.toNT() } } + .toSet() + .postProcess() // .expandNonterminalStubs(origCFG = this@intersectLevFSAP) .also { println("Bar-Hillel construction took ${clock.elapsedNow()}") } } @@ -101,6 +103,12 @@ fun CFG.unitProdRules(fsa: FSA): List>> = // else null // } +fun CFG.unitProdRules2(fsa: FSA): List, List>>> = + (unitProductions * fsa.nominalize().flattenedTriples) + .filter { (_, σ: Σᐩ, arc) -> (arc.π2)(σ) } +// .map { (A, σ, arc) -> "[${arc.π1}~$A~${arc.π3}]" to listOf(σ) } + .map { (A, σ, arc) -> listOf(arc.π1, A, arc.π3) to listOf(listOf(σ)) } + fun CFG.expandNonterminalStubs(origCFG: CFG) = flatMap { // println("FM: $it / ${it.RHS.first()} / ${it.RHS.first().isNonterminalStub()}") if (it.RHS.size != 1 || !it.RHS.first().isNonterminalStub()) listOf(it) diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt index df453709..e3ac4ce8 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt @@ -190,18 +190,19 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF if (parikhMap.size < fsa.width + fsa.height) throw Exception("WARNING: Parikh map size exceeded") var clock = TimeSource.Monotonic.markNow() - val nts = ConcurrentHashMap.newKeySet<Σᐩ>().apply { add("START") } + val nts = ConcurrentHashMap.newKeySet>().apply { add(listOf("START")) } val initFinal = - (fsa.init * fsa.final).map { (q, r) -> "START" to listOf("[$q~START~$r]") } + (fsa.init * fsa.final).map { (q, r) -> listOf("START") to listOf(listOf(q, "START", r)) } val transits = - fsa.Q.map { (q, a, r) -> "[$q~$a~$r]".also { nts.add(it) } to listOf(a) } + fsa.Q.map { (q, a, r) -> listOf(q, a, r).also { nts.add(it) } to listOf(listOf(a)) } // For every production A → σ in P, for every (p, σ, q) ∈ Q × Σ × Q // such that δ(p, σ) = q we have the production [p, A, q] → σ in P′. - val unitProds = unitProdRules(fsa) - .toSet().onEach { (a, _) -> nts.add(a) } + val unitProds = unitProdRules2(fsa) + .map { (a, b) -> a.also { nts.add(it) } to b } + .toSet() val ccClock = TimeSource.Monotonic.markNow() val compat: Array>> = computeNTCompat(this, fsa.levString) @@ -254,7 +255,8 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF .map { (a, b, c) -> if (MAX_PRODS < counter.incrementAndGet()) throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)") val (p, q, r) = states[a] to states[b] to states[c] - "[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]") +// "[$p~${allsym[A]}~$r]".also { nts.add(listOf(p, allsym[A], r)) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]") + listOf(p, allsym[A], r).also { nts.add(it) } to listOf(listOf(p, allsym[B], q), listOf(q, allsym[C], r)) } }.toList() @@ -262,8 +264,8 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF println("Levenshtein-Parikh constraints eliminated $elimCounter productions in ${lpClock.elapsedNow()}") // !isSyntheticNT() === is START or a terminal - fun Σᐩ.isSyntheticNT() = - first() == '[' && length > 1 // && last() == ']' && count { it == '~' } == 2 + fun List<Σᐩ>.isSyntheticNT() = size > 1 + fun List<Σᐩ>.toNT() = if (size == 1) first() else "[" + joinToString("~") + "]" val totalProds = binaryProds.size + transits.size + unitProds.size + initFinal.size println("Constructed ∩-grammar with $totalProds productions in ${clock.elapsedNow()}") @@ -273,6 +275,7 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF // A production, e.g., * -> * [G], can be removed if the synthetic nonterminal [G] does not exist, i.e., // every instance of [G] -> * * was incompatible with the FSA, so the nonterminal [G] is "unproductive". .filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } } + .map { (l, r) -> l.toNT() to r.map { it.toNT() } } .collect(Collectors.toSet()) .also { println("Eliminated ${totalProds - it.size} extra productions before normalization") } .jvmPostProcess(clock)