Skip to content

Commit

Permalink
parameterize the metric
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Sep 20, 2023
1 parent e4907c1 commit 1bf49ed
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ fun <T> levenshtein(o1: List<T>, o2: List<T>): Int {
return prev[o2.size]
}

fun multisetManhattanDistance(s1: Σᐩ, s2: Σᐩ): Int =
multisetManhattanDistance(s1.tokenizeByWhitespace().toList(), s2.tokenizeByWhitespace().toList())

fun <T> multisetManhattanDistance(q1: List<T>, q2: List<T>): Int {
val (s1, s2) = listOf(q1, q2).map { it.groupingBy { it }.eachCount() }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,83 +2,75 @@

package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.levenshtein
import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*

fun CFG.sortAll(s: Σᐩ): Set<Σᐩ> =
try { solveSortedFP(s.tokenizeByWhitespace())[START_SYMBOL]
// Returns all syntactically strings ordered by distance to withRespect
fun CFG.sortAll(s: Σᐩ, withRespectTo: Σᐩ): Set<Σᐩ> =
try { solveSortedFP(s.tokenizeByWhitespace(), withRespectTo)[START_SYMBOL]
?.map { it.first.tokenizeByWhitespace().filterNot { "ε" in it }.joinToString(" ") }?.toSet() ?: setOf() }
catch (e: Exception) { e.printStackTrace(); setOf() }

fun CFG.solveSortedFP(
tokens: List<Σᐩ>,
utMatrix: UTMatrix<Sort> = initialUTSMatrix(tokens),
withRespectTo: Σᐩ,
utMatrix: UTMatrix<Sort> =
initialUTSMatrix(tokens, sortwiseAlgebra(metric = { levenshtein(it.first, withRespectTo) })),
) = utMatrix.seekFixpoint().toFullMatrix()[0].last()

fun CFG.initialUTSMatrix(tokens: List<Σᐩ>, bmp: BiMap = bimap): UTMatrix<Sort> =
fun CFG.initialUTSMatrix(
tokens: List<Σᐩ>,
algebra: Ring<Sort>
): UTMatrix<Sort> =
UTMatrix(
ts = tokens.map { token ->
(if (token == HOLE_MARKER)
unitReachability.values.flatten().toSet().filter { root ->
bmp[root].any { it.size == 1 && it.first() in terminals }
bimap[root].any { it.size == 1 && it.first() in terminals }
}.toSet()
else bmp[listOf(token)])
else bimap[listOf(token)])
.associateWith {
if (token == HOLE_MARKER)
bmp[it].filter { it.size == 1 && it.first() in terminals && !it.first().isNonterminalStub() }
bimap[it].filter { it.size == 1 && it.first() in terminals && !it.first().isNonterminalStub() }
.map { it.first() }.toSet().map { it to if ("ε" in token) 0 else 1 }
else listOf(token to if ("ε" in token) 0 else 1)
}
}.toTypedArray()//.also { it.forEach { println("" + it.size + ":" + it) } }
,
algebra = sortwiseAlgebra
}.toTypedArray(),
algebra = algebra
)

// Maintains a sorted list of nonterminal roots and their leaves
val CFG.sortwiseAlgebra: Ring<Sort> by cache {
fun CFG.sortwiseAlgebra(metric: (SRec) -> Int): Ring<Sort> =
Ring.of(
nil = mapOf(),
plus = { x, y -> union(x, y) },
times = { x, y -> join(x, y) }
times = { x, y -> join(x, y, metric) }
)
}

operator fun SRec.plus(s2: SRec): SRec =
"$first ${s2.first}" to second + s2.second

// X ⊗ Z := { w | <x, z> ∈ X × Z, (w -> xz) ∈ P }
fun CFG.join(s1: Sort, s2: Sort): Sort =
fun CFG.join(s1: Sort, s2: Sort, metric: (SRec) -> Int = { it.second }): Sort =
(s1.keys * s2.keys).map { (x, z) ->
bimap[listOf(x, z)].also { it.size }.map { it to x to z }
bimap[listOf(x, z)].map { it to x to z }
}.flatten().map { (w, x, z) ->
((s1[x] ?: listOf()).toSet() * (s2[z] ?: listOf()).toSet())
.map { (q, r) ->
// println("Joining: $w to ${q.first} and ${r.first}")
w to (q + r)
}
.map { (q, r) -> w to (q + r) }
}.flatten().groupingBy { it.first }
.aggregate { _, acc, it, _ ->
// Maybe only propagate the top N according to metric to avoid blowup
(acc ?: listOf()) + it.second
val toInsert = it.second.let { it.first to metric(it) }
((acc ?: listOf()) + toInsert).sortedBy { it.second }.take(10)
}
// bimap.L2RHS.entries.mapNotNull { (k, v) ->
// val q = v.filter { it.size == 2 }.map { (a, b) ->
// val left = s1[a]
// val right = s2[b]
// if (left != null && right != null) {
// println("left: ${left.size}; right: ${right.size}")
// (left.toSet() * right.toSet())
// .map { (q, r) -> q + r }
// } else listOf()
// }.flatten().ifEmpty { null }
// if (q != null) k to q else null
// }.toMap()

fun union(s1: Sort, s2: Sort): Sort =
(s1.keys + s2.keys).associateWith { k ->
((s1[k] ?: listOf()) + (s2[k] ?: listOf())).sortedBy { it.second }
((s1[k] ?: listOf()) + (s2[k] ?: listOf())).sortedBy { it.second }.take(100)
}

// Mutable list that maintains a sorted order and has a fixed capacity.

// Map of root to the possible sets of leaves
// This is like a tree where we do not store the internal nodes
// The same root can have multiple derivations, but we only care about unique leaf sequences
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,13 +360,9 @@ class SetValiantTest {
val leaves = tree.contents()
assertEquals(expr, leaves)

val holExpr = "1 + _ + _ + 1"
val holExpr = "_ _ _ _ _ _ _ _"

// val trees = ocamlCFG.parseAll(holExpr)
// println("Found: ${trees.size} unique trees")
// trees.map { it.contents() }.forEach { assertTrue("$it was invalid!") { ocamlCFG.isValid(it) } }

val solutions = ocamlCFG.sortAll(holExpr)
val solutions = ocamlCFG.sortAll(holExpr, withRespectTo = "( false curry )")
println("Found: ${solutions.size} unique solutions")
solutions.forEach { println(it); assertTrue("$it was invalid!") { ocamlCFG.isValid(it) } }
}
Expand Down

0 comments on commit 1bf49ed

Please sign in to comment.