Skip to content

Commit

Permalink
implement simple pruning for Levenshtein automaton
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jan 4, 2024
1 parent d589ca1 commit f68bcc2
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.repair.Patch
import ai.hypergraph.kaliningraph.repair.*
import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.kaliningraph.types.times
import kotlin.math.*

// Only accept states that are within radius dist of (strLen, 0)
Expand Down Expand Up @@ -57,9 +58,10 @@ fun makeLevFSA(
str: List<Σᐩ>,
dist: Int,
alphabet: Set<Σᐩ>,
digits: Int = (str.size * dist).toString().length
digits: Int = (str.size * dist).toString().length,
ceaDist: CEADist? = null
): FSA =
(upArcs(str, dist, alphabet, digits) +
(upArcs(str, dist, alphabet, digits, ceaDist) +
diagArcs(str, dist, alphabet, digits) +
str.mapIndexed { i, it -> rightArcs(i, dist, it, digits) }.flatten() +
str.mapIndexed { i, it -> knightArcs(i, dist, it, digits) }.flatten())
Expand All @@ -82,7 +84,11 @@ private fun pd(i: Int, digits: Int) = i.toString().padStart(digits, '0')

/**
TODO: upArcs and diagArcs are the most expensive operations taking ~O(2n|Σ|) to construct.
We can probably do much better by only creating arcs that are contextually probable.
Later, the Bar-Hillel construction creates a new production for every triple QxQxQ, so it
increases the size of generated grammar by (2n|Σ|)^3. For this to be tractable on real
world grammars and code snippets, we must prune the transitions aggressively by only
creating arcs that are contextually likely.
See: [ai.hypergraph.kaliningraph.repair.CEAProb]
*/

Expand All @@ -92,10 +98,14 @@ private fun pd(i: Int, digits: Int) = i.toString().padStart(digits, '0')
(q_i,j−1 -s→ q_i,j)∈δ
*/

fun upArcs(str: List<Σᐩ>, dist: Int, alphabet: Set<Σᐩ>, digits: Int): TSA =
fun upArcs(str: List<Σᐩ>, dist: Int, alphabet: Set<Σᐩ>, digits: Int, ceaDist: CEADist? = null): TSA =
((0..<str.size + dist).toSet() * (1..dist).toSet() * alphabet)
.filter { (i, _, s) -> str.size <= i || str[i] != s }
.filter { (i, j, _) -> i <= str.size || i - str.size < j }
.filter { (i, j, s) ->
if (ceaDist == null || j != 1) false
else s in (ceaDist.insLeft[str.getOrElse(i - 1) { "BOS" }] ?: setOf())
}
.map { (i, j, s) -> i to j - 1 to s to i to j }.postProc(digits)

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package ai.hypergraph.kaliningraph.repair

import kotlin.math.pow

val contextCSV by lazy { pythonContext.lines().readContextCSV() }
val contextCSV: CEADist by lazy { pythonContext.lines().readContextCSV() }

fun List<String>.readContextCSV(diversity: Double = 1.0) =
drop(1).map { it.split(", ") }.associate {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ data class CEADist(val allProbs: Map<ContextEdit, Int>) {
val P_insert = allProbs.filter { it.key.type == EditType.INS }
val P_delSubOnCtx = P_delSub.keys.groupBy { it.context }
val P_insertOnCtx = P_insert.keys.groupBy { it.context }
val subLeft: Map<String, Set<String>> = allProbs.keys.filter { it.type == EditType.SUB }
.groupBy { it.context.left }.mapValues { it.value.map { it.newMid }.toSet() }
val insLeft: Map<String, Set<String>> = allProbs.keys.filter { it.type == EditType.INS }
.groupBy { it.context.left }.mapValues { it.value.map { it.newMid }.toSet() }
}

fun CFG.contextualRepair(broken: List<String>): Sequence<List<String>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ class BarHillelTest {
val toRepair = origStr.tokenizeByWhitespace()
val maxLevDist = 3
val levBall = makeLevFSA(toRepair, maxLevDist, gram.terminals)
println("Total transitions in FSA: ${levBall.Q.size}")
// throw Exception("")
// println(levBall.toDot())
// throw Exception("")
val intGram = gram.intersectLevFSA(levBall)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,14 @@ class ProbabilisticLBH {

val s2pg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
// val origStr = "NAME = ( NAME . NAME ( NAME NEWLINE"
val origStr = invalidPythonStatements.lines().first() + " NEWLINE"
// invalidPythonStatements.lines().drop(1).forEach {
// val origStr = invalidPythonStatements.lines().first() + " NEWLINE"
invalidPythonStatements.lines().drop(1).forEach {
val clock = TimeSource.Monotonic.markNow()
// val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
// val origStr = "$it NEWLINE"
val origStr = "$it NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val levDist = 2
println("Top terms: ${topTerms.joinToString(", ")}")
val levBall = makeLevFSA(toRepair, levDist, topTerms)
val levBall = makeLevFSA(toRepair, levDist, topTerms, ceaDist = contextCSV)
println("Total transitions in FSA: ${levBall.Q.size}")
println("Prompt: $origStr")
println("Alphabet: ${levBall.alphabet}")
Expand All @@ -137,7 +136,7 @@ class ProbabilisticLBH {
assertTrue(levBall.recognizes(it))
}.toList()
.also { println("TOTAL LBH REPAIRS (${clock.elapsedNow()}): ${it.size}\n\n") }
// }
}
}

fun CFG.getS2PNT(string: String) =
Expand Down

0 comments on commit f68bcc2

Please sign in to comment.