use binary search and implement multi-edit pruning
breandan committed Oct 28, 2024
1 parent ab5d42a commit ae4eba6
Showing 3 changed files with 133 additions and 25 deletions.
@@ -1,5 +1,6 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.repair.MAX_RADIUS
import ai.hypergraph.kaliningraph.types.*

// https://en.wikipedia.org/wiki/Regular_grammar
@@ -45,18 +46,6 @@ fun pruneInactiveRules(cfg: CFG): CFG =
TODO("Identify and prune all nonterminals t generating" +
"a finite language rooted at t and disjoint from the upward closure.")

fun CFG.maxParsableFragment(tokens: List<String>, pad: Int = 3): Pair<Int, Int> =
((1..tokens.size).firstOrNull { i ->
val blocked =
tokens.mapIndexed { j, t -> if (j < i) t else "_" } + List(pad) { "_" }
// println(blocked)
blocked !in language
} ?: tokens.size) to ((2..tokens.size).firstOrNull { i ->
val blocked = List(pad) { "_" } +
tokens.mapIndexed { j, t -> if (tokens.size - i < j) t else "_" }
// println(blocked)
blocked !in language
}?.let { tokens.size - it } ?: 0)

// REL ⊂ CFL ⊂ CJL
operator fun REL.contains(s: Σᐩ): Bln = s in reg.asCFG.language
143 changes: 131 additions & 12 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Levenshtein.kt
@@ -60,33 +60,55 @@ fun Σᐩ.unpackCoordinates() =
/** Uses nominal arc predicates. See [NOM] for denominalization. */
fun makeLevFSA(
str: List<Σᐩ>,
dist: Int,
digits: Int = (str.size * dist).toString().length,
bounds: Pair<Int, Int> = str.size to 0
maxRad: Int, // Maximum Levenshtein distance the automaton should accept
/**
* (x, y), where x is the first index by which at least one edit must already have occurred, and y
* is the last index at which at least one more edit remains to be made in the string.
* We can use (x, y) to prune states representing trajectories that have either spent their
* entire edit allocation (yet provably have at least one edit left to make) or made no
* edits so far (yet provably need at least one edit) to reach a parsable state.
* See [maxParsableFragment] for how these bounds are proven.
*/
singleEditBounds: Pair<Int, Int> = str.size to 0,
/**
* Range provably containing two or more edits -- should be kept minimal for efficiency.
* We can use this to prune states representing trajectories that have one or fewer
* edits left in their budget but need at least two to reach a final parsable state, or
* that have used only one edit from their budget but must have made at least two
* by this point in order to reach a parsable state. This proof is expensive to
* find but worthwhile for long strings. See [smallestRangeWithNoSingleEditRepair].
*/
// multiEditBounds: IntRange = 0 until str.size
digits: Int = (str.size * maxRad).toString().length,
): FSA =
(upArcs(str, dist, digits) +
diagArcs(str, dist, digits) +
str.mapIndexed { i, it -> rightArcs(i, dist, it, digits) }.flatten() +
str.mapIndexed { i, it -> knightArcs(i, dist, it, digits, str) }.flatten())
(upArcs(str, maxRad, digits) +
diagArcs(str, maxRad, digits) +
str.mapIndexed { i, it -> rightArcs(i, maxRad, it, digits) }.flatten() +
str.mapIndexed { i, it -> knightArcs(i, maxRad, it, digits, str) }.flatten())
.also {
println("Levenshtein-${str.size}x$dist automaton had ${it.size} arcs initially!")
println("Levenshtein-${str.size}x$maxRad automaton had ${it.size} arcs initially!")
}.filter { arc ->
listOf(arc.first.unpackCoordinates(), arc.third.unpackCoordinates())
.all { (i, j) -> (0 < j || i <= bounds.first) && (j < dist || i >= bounds.second - 2) }
.all { (i, j) ->
(0 < j || i <= singleEditBounds.first) // Prunes bottom right
&& (j < maxRad || i >= singleEditBounds.second - 2) // Prunes top left
// && (1 < j || i <= multiEditBounds.last + 2 || maxRad == 1) // Prunes bottom right
// && (j < maxRad - 1 || i > multiEditBounds.first - 3 || maxRad == 1) // Prunes top left
}
}
.let { Q ->
val initialStates = setOf("q_" + pd(0, digits).let { "$it/$it" })

val finalStates = mutableSetOf<String>()
Q.states.forEach {
val (i, j) = it.unpackCoordinates()
if ((str.size - i + j).absoluteValue <= dist) finalStates.add(it)
if ((str.size - i + j).absoluteValue <= maxRad) finalStates.add(it)
}

FSA(Q, initialStates, finalStates)
.also { it.height = dist; it.width = str.size; it.levString = str }
.also { it.height = maxRad; it.width = str.size; it.levString = str }
// .nominalize()
.also { println("Levenshtein-${str.size}x$dist automaton had ${Q.size} arcs finally!") }
.also { println("Levenshtein-${str.size}x$maxRad automaton had ${Q.size} arcs after pruning!") }
}
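
To make the bound-based pruning above concrete, here is a minimal standalone sketch (not part of the commit) of the same state-filtering predicate; keepState and the sample values of maxRad and singleEditBounds are purely hypothetical.

// Sketch of the filter predicate above: i is the position in the string, j is the
// number of edits consumed so far. All concrete values below are hypothetical.
fun keepState(i: Int, j: Int, maxRad: Int, singleEditBounds: Pair<Int, Int>): Boolean =
  (0 < j || i <= singleEditBounds.first) &&           // no-edit trajectories may not pass the first forced edit
    (j < maxRad || i >= singleEditBounds.second - 2)  // spent trajectories must already be near the last forced edit

fun main() {
  val bounds = 4 to 8  // hypothetical single-edit bounds for a 10-token string
  println(keepState(i = 6, j = 0, maxRad = 2, singleEditBounds = bounds)) // false: no edit made by index 4
  println(keepState(i = 3, j = 2, maxRad = 2, singleEditBounds = bounds)) // false: edit budget spent too early
  println(keepState(i = 5, j = 1, maxRad = 2, singleEditBounds = bounds)) // true: state survives pruning
}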

private fun pd(i: Int, digits: Int) = i.toString().padStart(digits, '0')
@@ -167,6 +189,103 @@ fun List<Π5<Int, Int, Σᐩ, Int, Int>>.postProc(digits: Int) =
"q_$a/$b" to s to "q_$d/$e"
}.toSet()

/**
* Levenshtein automata optimizations that identify ranges which must contain an edit for the string to be parsable.
* These serve as proofs of the unreachability of certain states in the Levenshtein automaton.
* For example, if we know that a certain range must contain at least one edit for the string to be parsable,
* then we have a proof that any state which has not yet made an edit after that range is unreachable,
* and any state which has exhausted all of its edits before that range is likewise unreachable.
*/

fun CFG.maxParsableFragmentL(tokens: List<String>, pad: Int = 3): Pair<Int, Int> =
((1..tokens.size).toList().firstOrNull { i ->
blockForward(tokens, i, pad) !in language
} ?: tokens.size) to ((2..tokens.size).firstOrNull { i ->
blockBackward(tokens, i, pad) !in language
}?.let { tokens.size - it } ?: 0)

fun blockForward(tokens: List<String>, i: Int, pad: Int = 3) =
tokens.mapIndexed { j, t -> if (j < i) t else "_" } + List(pad) { "_" }

fun blockBackward(tokens: List<String>, i: Int, pad: Int = 3) =
List(pad) { "_" } + tokens.mapIndexed { j, t -> if (tokens.size - i < j) t else "_" }
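
A small usage sketch (not part of the commit) showing what these masks look like on a toy token list; the tokens are made up, and "_" is assumed to act as a hole in the language membership check.

fun main() {
  val toks = listOf("x", "=", "(", "1", "+", "2")
  // Keeps the first three tokens, holes out the rest, and appends the trailing pad:
  println(blockForward(toks, 3).joinToString(" "))  // x = ( _ _ _ _ _ _
  // Prepends the pad and keeps only the last two tokens:
  println(blockBackward(toks, 3).joinToString(" ")) // _ _ _ _ _ _ _ + 2
}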

// Binary search for the max parsable fragment. Equivalent to the linear search, but faster
fun CFG.maxParsableFragmentB(tokens: List<String>, pad: Int = 3): Pair<Int, Int> =
((1..tokens.size).toList().binarySearch { i ->
val blocked = blockForward(tokens, i, pad)
val blockedInLang = blocked in language
// println(blocked.joinToString(" "))
if (blockedInLang) -1 else {
val blockedPrev = blockForward(tokens, i - 1, pad)
val blockedPrevInLang = i == 1 || blockedPrev in language
if (!blockedInLang && blockedPrevInLang) 0 else 1
}
}.let { if (it < 0) tokens.size else it + 1 }) to ((2..tokens.size).toList().binarySearch { i ->
val blocked = blockBackward(tokens, i, pad)
val blockedInLang = blocked in language
// println(blocked.joinToString(" "))
if (blockedInLang) -1 else {
val blockedPrev = blockBackward(tokens, i - 1, pad)
val blockedPrevInLang = i == 2 || blockedPrev in language
if (!blockedInLang && blockedPrevInLang) 0 else 1
}
}.let { if (it < 0) 0 else (tokens.size - it - 2).coerceAtLeast(0) })
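
A self-contained sketch (not part of the commit) of the same boundary-finding pattern with Kotlin's List.binarySearch, using a stand-in predicate parsablePrefix in place of the blocked-in-language check; n and firstBadPrefix are made-up values. The comparison must be monotone: -1 while the masked prefix still parses, 0 exactly at the first failing length, 1 afterwards.

fun main() {
  val n = 12
  val firstBadPrefix = 7                        // pretend prefixes of length >= 7 stop being parsable
  fun parsablePrefix(i: Int) = i < firstBadPrefix

  val idx = (1..n).toList().binarySearch { i ->
    if (parsablePrefix(i)) -1                   // still parsable: search to the right
    else if (i == 1 || parsablePrefix(i - 1)) 0 // first failing prefix length: found
    else 1                                      // past the boundary: search to the left
  }
  println(if (idx < 0) n else idx + 1)          // prints 7, matching the linear scan's firstOrNull
}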

fun maskEverythingButRange(tokens: List<String>, range: IntRange): List<String> =
tokens.mapIndexed { i, t -> if (i in range) t else "_" }

fun CFG.hasSingleEditRepair(tokens: List<String>, range: IntRange): Boolean =
maskEverythingButRange(tokens, range).let { premask ->
val toCheck = if (range.first < 0) List(-range.first) { "_" } + premask
else if (tokens.size <= range.last) premask + List(range.last - tokens.size) { "_" }
else premask

(maxOf(0, range.first) until minOf(tokens.size, range.last + 1)).any { i ->
toCheck.mapIndexed { j, t -> if (j == i) "_" else t }.also { println(it.joinToString(" ")) } in language
}
}

// Tries to shrink a multi-edit range until it has a single edit repair
fun CFG.tryToShrinkMultiEditRange(tokens: List<String>, range: IntRange): IntRange {
fun IntRange.tryToShrinkLeft(): IntRange {
var left = first + 1
while (left < last - 2 && !hasSingleEditRepair(tokens, left until last)) left++
return left until last
}

fun IntRange.tryToShrinkRight(): IntRange {
var right = last
while (first < right - 1 && !hasSingleEditRepair(tokens, first until right)) right--
return first until right
}

return range.tryToShrinkLeft().tryToShrinkRight()
}

fun CFG.smallestRangeWithNoSingleEditRepair(tokens: List<String>, stride: Int = MAX_RADIUS + 2): IntRange {
if (tokens.size < 30) return 0..tokens.size
else {
val rangeLen = (0.4 * tokens.size).toInt()
val indices = -stride until (tokens.size - rangeLen + stride) step stride
var rmin = 0..tokens.size
for (i in indices) {
println("Checking range $i..${i + rangeLen}")
val r = i until i + rangeLen
if (hasSingleEditRepair(tokens, r)) continue
println("Found multi-edit range $r")
val rmin1 = tryToShrinkMultiEditRange(tokens, r)
println("Shrunk to $rmin1")
if (rmin1.last - rmin1.first < rmin.last - rmin.first) {
rmin = rmin1
if (rmin.last - rmin.first < 0.2 * tokens.size && rmin != 0..tokens.size) return rmin
}
}

return rmin
}
}
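
A rough illustration (not part of the commit) of the sliding-window scan above, with a stand-in predicate in place of the grammar-backed hasSingleEditRepair check; n, mustEdit, and singleEditRepairable are hypothetical stand-ins.

fun main() {
  val n = 50
  val mustEdit = setOf(18, 31)              // pretend two separate positions each require an edit
  // A window is single-edit repairable iff it covers at most one of the required edit sites.
  fun singleEditRepairable(r: IntRange) = mustEdit.count { it in r } <= 1

  val stride = 5
  val rangeLen = (0.4 * n).toInt()          // window length, 40% of the string as above
  for (i in -stride until (n - rangeLen + stride) step stride) {
    val r = i until i + rangeLen
    if (!singleEditRepairable(r)) {
      // The real code would now shrink r from both sides via tryToShrinkMultiEditRange.
      println("Multi-edit window found: $r") // here: 15..34
      break
    }
  }
}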

fun allPairsLevenshtein(s1: Set<Σᐩ>, s2: Set<Σᐩ>) =
(s1 * s2).sumOf { (a, b) -> levenshtein(a, b) }

@@ -203,7 +203,7 @@ fun CFG.jvmIntersectLevFSAP(
}.toList().also {
val candidates = (fsa.states.size * nonterminals.size * fsa.states.size)
val fraction = it.size.toDouble() / candidates
println("Fraction of valid triples: ${it.size}/$candidates$fraction")
println("Fraction of valid LBH triples: ${it.size}/$candidates$fraction")
}.forEach { ct2[it.π11][it.π3][it.π21] = true }
println("Precomputed LP constraints in ${ctClock.elapsedNow()}")

