Skip to content

Commit

Permalink
use more sensible names
Browse files Browse the repository at this point in the history
naming is one of the three hardest problems in computer science
  • Loading branch information
breandan committed Sep 22, 2023
1 parent b5f26ac commit 00bda88
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ val CFG.terminalUnitProductions: Set<Production>
by cache { filter { it.RHS.size == 1 && it.RHS[0] !in nonterminals } }
// Productions whose right-hand side consists of exactly one symbol
val CFG.unitProductions: Set<Production> by cache { filter { it.RHS.size == 1 } }
// Every production that is not one of the terminal unit productions
val CFG.nonterminalProductions: Set<Production> by cache { filter { it !in terminalUnitProductions } }
// Left-hand sides of the terminal unit productions, i.e., nonterminals that directly yield a single non-nonterminal symbol
val CFG.unitNonterminals: Set<Σᐩ> by cache { terminalUnitProductions.map { it.LHS }.toSet() }
// Lookup structure (BiMap) constructed from this grammar
val CFG.bimap: BiMap by cache { BiMap(this) }
// Maps nonterminal sets to their terminals, n.b., each terminal can be generated
// by multiple nonterminals, and each nonterminal can generate multiple terminals
Expand Down Expand Up @@ -78,7 +79,7 @@ val CFG.reachability by cache { mutableMapOf<Σᐩ, Set<Σᐩ>>() }
val CFG.unitReachability by cache {
symbols.associateWith { from ->
LabeledGraph {
unitProductions.map { it.LHS to it.RHS.first() }
unitProductions.map { it.LHS to it.RHS[0] }
// .filter { (a, b) -> nonterminals.containsAll(listOf(a, b)) }
.forEach { (a, b) -> a - b }
}.let {
Expand Down Expand Up @@ -215,6 +216,9 @@ class BiMap(cfg: CFG) {
.map { it.value.map { v -> v to it.key[0] to it.key[1] } }.flatten()
val X2WZ: Map<Σᐩ, List<Triple<Σᐩ, Σᐩ, Σᐩ>>> = TRIPL.groupBy { it.second }
.mapValues { it.value.map { it.first to it.second to it.third } }
// Maps each left-hand side to the set of symbols it produces through a
// single-symbol right-hand side that is not one of cfg's nonterminals
val UNITS =
  cfg.filter { prod -> prod.RHS.size == 1 && prod.RHS[0] !in cfg.nonterminals }
    .groupBy(keySelector = { prod -> prod.LHS }, valueTransform = { prod -> prod.RHS[0] })
    .mapValues { (_, symbols) -> symbols.toSet() }
// Looks up the entry stored in R2LHS for the symbol list p; yields an empty set when there is none
operator fun get(p: List<Σᐩ>): Set<Σᐩ> {
  val found = R2LHS[p]
  return found ?: emptySet()
}
// Looks up the entry stored in L2RHS for the symbol p; yields an empty set when there is none
operator fun get(p: Σᐩ): Set<List<Σᐩ>> {
  val found = L2RHS[p]
  return found ?: emptySet()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,19 @@ class Repair constructor(val orig: List<Σᐩ>, val edit: Edit, val result: List

// True when the string form of this repair (resToStr) equals the expected ground truth
fun matches(groundTruth: String): Boolean {
  val rendered = resToStr()
  return rendered == groundTruth
}

// Computes a "fingerprint" of the repair to avoid redundant results
// Each fingerprint can be lazily expanded to a sequence of repairs
// formed by the Cartesian product of tokens at each change position
// e.g., "C + C" -> "1 + 2", "1 + 3", "2 + 1", "2 + 3", "3 + 1", "3 + 2"... etc.

// Hash is derived from the result alone, matching equals()
override fun hashCode(): Int {
  return result.hashCode()
}
// Two repairs are equal exactly when the other object is a Repair with the same result
override fun equals(other: Any?): Boolean =
  other is Repair && result == other.result

// Formats timeMS as seconds, truncated to four characters and suffixed with "s";
// a timeMS of -1 is rendered from the placeholder "N/A" instead
fun elapsed(): String {
  val seconds = if (timeMS == -1L) "N/A" else "${timeMS / 1000.0}"
  return seconds.take(4) + "s"
}
// String form of the score, truncated to at most five characters
fun scoreStr(): String {
  val rendered = "$score"
  return rendered.take(5)
}

// TODO: Computes a "fingerprint" of the repair to avoid redundant results
// Each fingerprint can be lazily expanded to a sequence of repairs
// formed by the Cartesian product of tokens at each change position
// e.g., "C + C" -> "1 + 2", "1 + 3", "2 + 1", "2 + 3", "3 + 1", "3 + 2"... etc.

/**
* This can be used to generate a sequence of repairs with the same edit fingerprint but alternate
* tokens at each change location. This method may optionally be called on any Repair, but for the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,16 @@ import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*

// Returns all syntactically valid strings, ordered by the given metric (e.g., distance to a reference string)
fun CFG.sortAll(s: Σᐩ, withRespectTo: Σᐩ): Set<Σᐩ> =
try { solveSortedFP(s.tokenizeByWhitespace(), withRespectTo.tokenizeByWhitespace())
?.sortedBy { it.second }
?.map { it.first.filterNot { "ε" in it }.joinToString(" ") }?.toSet() ?: setOf() }
// Tokenizes s by whitespace, solves for the choices under the start symbol, and
// returns their sanitized, space-joined strings in ascending order of weight.
// Any exception is printed and results in an empty set (best effort).
fun CFG.sortAll(s: Σᐩ, metric: ChoiceMetric): Set<Σᐩ> =
  try {
    val ranked = solveSortedFP(s.tokenizeByWhitespace(), metric)
      ?.sortedBy { choice -> choice.weight }
    ranked?.map { choice -> choice.sanitize().joinToString(" ") }?.toSet() ?: setOf()
  } catch (e: Exception) {
    e.printStackTrace()
    setOf()
  }

fun CFG.solveSortedFP(
tokens: List<Σᐩ>,
withRespectTo: List<Σᐩ>,
utMatrix: UTMatrix<Sort> =
initialUTSMatrix(tokens,
sortwiseAlgebra(metric = {
levenshtein(it.first.filterNot { "ε" in it }, withRespectTo).toFloat()
})
),
metric: ChoiceMetric,
utMatrix: UTMatrix<Sort> = initialUTSMatrix(tokens, sortwiseAlgebra(metric)),
) = utMatrix.seekFixpoint().toFullMatrix()[0].last()[START_SYMBOL]

fun CFG.initialUTSMatrix(
Expand All @@ -30,66 +25,71 @@ fun CFG.initialUTSMatrix(
): UTMatrix<Sort> =
UTMatrix(
ts = tokens.map { token ->
(if (token == HOLE_MARKER)
unitReachability.values.flatten().toSet().filter { root ->
bimap[root].any { it.size == 1 && it.first() in terminals }
}.toSet()
else bimap[listOf(token)])
.associateWith {
if (token == HOLE_MARKER)
bimap[it].filter { it.size == 1 && it.first() in terminals && !it.first().isNonterminalStub() }
.map { it.first() }.map { listOf(it) to if ("ε" in token) 0f else 1f }.toSet()
else setOf(listOf(token) to if ("ε" in token) 0f else 1f)
(if (token != HOLE_MARKER) bimap[listOf(token)] else unitNonterminals)
.associateWith { nt ->
if (token != HOLE_MARKER) setOf(Choice(token))
else bimap.UNITS[nt]?.map { Choice(it) }?.toSet() ?: setOf()
}
}.toTypedArray(),
algebra = algebra
)

// Maintains a sorted list of nonterminal roots and their leaves
fun CFG.sortwiseAlgebra(metric: (SRec) -> Float): Ring<Sort> =
fun CFG.sortwiseAlgebra(metric: ChoiceMetric): Ring<Sort> =
Ring.of(
nil = mapOf(),
plus = { x, y -> union(x, y) },
times = { x, y -> join(x, y, metric) }
times = { x, y -> join(x, y, metric) },
)

// Combines two records component-wise: the first components are added together
// and so are the second components (for SRec = Π2<List<Σᐩ>, Float> this is
// presumably list concatenation and float addition; confirm π1/π2 semantics)
operator fun SRec.plus(s2: SRec): SRec = (π1 + s2.π1) to (π2 + s2.π2)

const val MAX_CAPACITY = 100
// X ⊗ Z := { w | <x, z> ∈ X × Z, (w -> xz) ∈ P }
fun CFG.join(s1: Sort, s2: Sort, metric: (SRec) -> Float = { it.second }): Sort =
bimap.TRIPL.filter { (_, x, z) -> x in s1 && z in s2 }
fun CFG.join(X: Sort, Z: Sort, metric: ChoiceMetric = { it.weight }): Sort =
bimap.TRIPL.filter { (_, x, z) -> x in X && z in Z }
.map { (w, x, z) ->
((s1[x] ?: setOf()) * (s2[z] ?: setOf()))
((X[x] ?: setOf()) * (Z[z] ?: setOf()))
.map { (q, r) -> w to (q + r) }
}.flatten().groupingBy { it.first }
.aggregate<Pair<Σᐩ, SRec>, Σᐩ, MutableList<SRec>> { _, acc, it, _ ->
val toInsert = it.second.let { it.first to metric(it) }
val list = (acc ?: mutableListOf())
val idx = list.binarySearch(toInsert,
compareBy<SRec> { it.second }
// .thenBy { it.first.hashCode() }
)
list.add(if (idx < 0) -idx - 1 else idx, toInsert)
// if (idx < 0) list.add(-idx - 1, toInsert)
.aggregate<Pair<Σᐩ, Choice>, Σᐩ, MutableList<Choice>> { _, acc, it, _ ->
val choice = Choice(it.second.tokens, metric(it.second))
val list = acc ?: mutableListOf()
val idx = list.binarySearch(choice, Choice.comparator)
if (idx < 0) list.add(-idx - 1, choice) // Only if not already present
list.apply { if (MAX_CAPACITY < size) removeLast() }
}
.mapValues { it.value.toSet() }

fun union(s1: Sort, s2: Sort): Sort =
(s1.keys + s2.keys).associateWith { k ->
((s1[k] ?: setOf()) + (s2[k] ?: setOf()))
// .sortedBy { it.second }.take(100).toSet()
}
// Key-wise union of two sorts: every key present in either operand is mapped
// to the union of the sets stored under it (a missing key counts as empty)
fun union(l: Sort, r: Sort): Sort {
  val roots = l.keys + r.keys
  return roots.associateWith { root -> l[root].orEmpty() + r[root].orEmpty() }
}

// Map of root to the possible sets of leaves
// This is like a tree where we do not store the internal nodes
// One root can represent many strings, but we only care about unique leaf sequences
// Map of root to the possible sets of token sequences it can produce in context
// This is identical to a forest minus internal branches, just roots and leaves
// Each root represents many strings, we only care about unique leaf sequences
// Maintains a sort ordering based on some metric of the most likely derivations
typealias Sort = Map<Σᐩ, Set<SRec>>
typealias Sort = Map<Σᐩ, Set<Choice>>
typealias ChoiceMetric = (Choice) -> Float
// Substring and some metric (e.g., number of blanks)
// TODO: Associate a more concrete semantics with second value,
// but for now just the number of terminals. For example,
// we could use perplexity of a Markov chain or the length
// of the longest common substring with the original string.
typealias SRec = Π2<List<Σᐩ>, Float>
/**
 * A candidate token sequence together with the weight used to rank it.
 *
 * Choices are ordered by ascending [weight]. Ties are broken first by the hash of
 * [tokens] (cheap) and finally by comparing the token sequences element by element,
 * so two choices compare as equal only when they have the same weight and the same
 * tokens. Without the last step, distinct sequences whose hashes collide would be
 * treated as duplicates by a comparator-based search and one would be dropped.
 */
data class Choice(val tokens: List<Σᐩ>, val weight: Float): Comparable<Choice> {
  // Single-token choice: a token containing "ε" weighs 0, any other token weighs 1
  constructor(token: Σᐩ): this(listOf(token), if ("ε" in token) 0f else 1f)

  override fun compareTo(other: Choice): Int = comparator.compare(this, other)

  // Concatenates the token sequences and sums the weights
  operator fun plus(other: Choice): Choice =
    Choice(tokens + other.tokens, weight + other.weight)

  // The tokens with every token containing "ε" removed
  fun sanitize(): List<Σᐩ> = tokens.filterNot { "ε" in it }

  companion object {
    val comparator: Comparator<Choice> =
      compareBy<Choice> { it.weight }
        .thenBy { it.tokens.hashCode() }
        // Only reached when weight and hash both tie: resolve hash collisions exactly
        .thenComparator { a, b -> compareTokens(a.tokens, b.tokens) }

    // Lexicographic comparison of two token lists; a strict prefix sorts first
    private fun compareTokens(xs: List<Σᐩ>, ys: List<Σᐩ>): Int {
      for (i in 0 until minOf(xs.size, ys.size)) {
        val c = xs[i].compareTo(ys[i])
        if (c != 0) return c
      }
      return xs.size.compareTo(ys.size)
    }
  }
}

// Builds a ChoiceMetric that scores a choice by the Levenshtein distance between
// its sanitized tokens and the whitespace-tokenized reference string.
// The reference is tokenized once, when the metric is created.
fun levMetric(withRespectTo: Σᐩ): ChoiceMetric {
  val reference = withRespectTo.tokenizeByWhitespace()
  return { choice -> levenshtein(choice.sanitize(), reference).toFloat() }
}
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,14 @@ class SetValiantTest {
VO -> = | < | `||` | `&&`
I -> 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
B -> true | false
""".trimIndent().parseCFG()
""".trimIndent().parseCFG().noNonterminalStubs

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.SetValiantTest.testOCaml"
*/
@Test
fun testOCaml() {
val expr = "1 + <I> + 2"
val expr = "1 + 2 + 3"
val tree = ocamlCFG.parse(expr)!!
println(tree.prettyPrint())
val leaves = tree.contents()
Expand All @@ -363,7 +363,7 @@ class SetValiantTest {
val holExpr = "_ _ _ _ _ _ _ _ _ _"

measureTime {
val solutions = ocamlCFG.sortAll(holExpr, withRespectTo = "( false curry )")
val solutions = ocamlCFG.sortAll(holExpr, levMetric("( false curry )"))
println("Found: ${solutions.size} unique solutions")
solutions.forEach { println(it); assertTrue("$it was invalid!") { ocamlCFG.isValid(it) } }
}.also { println("Finished in ${it.inWholeMilliseconds}ms.") }
Expand Down

0 comments on commit 00bda88

Please sign in to comment.