Skip to content

Commit

Permalink
prototype SortValiant, a new lightweight solver
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Sep 18, 2023
1 parent 95877b6 commit f2fbab4
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 57 deletions.
19 changes: 10 additions & 9 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import org.jetbrains.kotlin.gradle.targets.js.nodejs.*
plugins {
signing
`maven-publish`
kotlin("multiplatform") version "1.9.10"
kotlin("multiplatform") version "1.9.20-Beta"
// kotlin("jupyter.api") version "0.11.0-225"
id("com.github.ben-manes.versions") version "0.47.0"
id("com.github.ben-manes.versions") version "0.48.0"
id("io.github.gradle-nexus.publish-plugin") version "2.0.0-rc-1"
}

Expand Down Expand Up @@ -137,15 +137,15 @@ kotlin {

implementation("org.jetbrains.kotlinx:kotlinx-html-jvm:$kotlinxVersion") // TODO: why is this necessary?

implementation("org.jetbrains.lets-plot:platf-awt-jvm:4.0.0")
implementation("org.jetbrains.lets-plot:lets-plot-kotlin-jvm:4.4.2")
implementation("org.jetbrains.lets-plot:platf-awt-jvm:4.0.1")
implementation("org.jetbrains.lets-plot:lets-plot-kotlin-jvm:4.4.3")

// https://arxiv.org/pdf/1908.10693.pdf
// implementation("com.datadoghq:sketches-java:0.7.0")

// Cache PMF/CDF lookups for common queries

implementation("org.apache.datasketches:datasketches-java:4.1.0")
implementation("org.apache.datasketches:datasketches-java:4.2.0")

// implementation("com.github.analog-garage:dimple:master-SNAPSHOT")

Expand Down Expand Up @@ -181,7 +181,7 @@ kotlin {

implementation("junit:junit:4.13.2")
implementation("org.jetbrains:annotations:24.0.1")
implementation("org.slf4j:slf4j-simple:2.0.7")
implementation("org.slf4j:slf4j-simple:2.0.9")

// http://www.ti.inf.uni-due.de/fileadmin/public/tools/grez/grez-manual.pdf
// implementation(files("$projectDir/libs/grez.jar"))
Expand All @@ -204,8 +204,9 @@ kotlin {
implementation("org.apache.tinkerpop:gremlin-core:$tinkerpopVersion")
implementation("org.apache.tinkerpop:tinkergraph-gremlin:$tinkerpopVersion")
implementation("info.debatty:java-string-similarity:2.0.0")
implementation("org.eclipse.collections:eclipse-collections-api:12.0.0.M2")
implementation("org.eclipse.collections:eclipse-collections:12.0.0.M2")
val eccVersion = "12.0.0.M3"
implementation("org.eclipse.collections:eclipse-collections-api:$eccVersion")
implementation("org.eclipse.collections:eclipse-collections:$eccVersion")

implementation(kotlin("scripting-jsr223"))
}
Expand All @@ -216,7 +217,7 @@ kotlin {
implementation(kotlin("test"))
implementation(kotlin("test-common"))
implementation(kotlin("test-annotations-common"))
implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.4.0")
implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.4.1")
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/ECA.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ fun <A> PKernel<A>.nullity(): FKernel<Int> =
2 != null).compareTo(false) to
3 != null).compareTo(false)

val ecaAlgebra = kernelAlgebra<Boolean>()
fun initializeECA(len: Int, cc: (Int) -> Boolean = { true }) =
val ecaAlgebra = kernelAlgebra<𝔹>()
fun initializeECA(len: Int, cc: (Int) -> 𝔹 = { true }) =
FreeMatrix(ecaAlgebra, len, 1) { r, c -> null to cc(r) to null }

// Create a tridiagonal (Toeplitz) matrix
Expand All @@ -33,8 +33,8 @@ fun <A> KernelMatrix<A>.genMat(algebra: Ring<PKernel<A>> = kernelAlgebra<A>()):
else null to null to null
}

fun BooleanArray.toECA() = initializeECA(size) { this[it] }
fun BooleanArray.evolve(steps: Int = 1): BooleanArray =
fun 𝔹ⁿ.toECA() = initializeECA(size) { this[it] }
fun 𝔹ⁿ.evolve(steps: Int = 1): 𝔹ⁿ =
toECA().evolve(steps = steps, rule = { (π2 && !π1) ||2 xor π3) }).data.map { it!!.second!! }.toBooleanArray()

tailrec fun <A> KernelMatrix<A>.evolve(
Expand All @@ -45,8 +45,8 @@ tailrec fun <A> KernelMatrix<A>.evolve(
if (steps == 0) this
else (circulantMatrix * this).nonlinearity(rule).evolve(rule, circulantMatrix, steps - 1)

fun FreeMatrix<PKernel<Boolean>>.str() = transpose.map { if (it?.π2 == true) "1" else " " }.toString()
fun FreeMatrix<PKernel<Boolean>>.print() = println(str())
fun FreeMatrix<PKernel<𝔹>>.str() = transpose.map { if (it?.π2 == true) "1" else " " }.toString()
fun FreeMatrix<PKernel<𝔹>>.print() = println(str())

fun <A> KernelMatrix<A>.nonlinearity(rule: FKernel<A>.() -> A): KernelMatrix<A> =
FreeMatrix(numRows, 1) { r, c -> null to (this[r, c] as FKernel<A>).rule() to null }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
@file:Suppress("NonAsciiCharacters")
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.graphs.LabeledGraph
Expand All @@ -6,7 +7,6 @@ import ai.hypergraph.kaliningraph.types.*
import kotlin.jvm.JvmName
import kotlin.time.*

@Suppress("NonAsciiCharacters")
typealias Σᐩ = String
typealias Production = Π2<Σᐩ, List<Σᐩ>>
// TODO: make this immutable
Expand Down Expand Up @@ -165,7 +165,7 @@ private fun Σᐩ.isTreelikeNonterminalIn(
cfg: CFG,
reachables: Set<Σᐩ> = cfg.reachableSymbols(this) - this,
nonTreeLike: Set<Σᐩ> = setOf(this)
): Boolean = when {
): 𝔹 = when {
"ε" in this -> true
(reachables intersect nonTreeLike).isNotEmpty() -> false
else -> reachables.all { it in cfg.terminals ||
Expand All @@ -180,7 +180,7 @@ class JoinMap(val CFG: CFG) {
.associateWith { subsets -> subsets.let { (l, r) -> join(l, r) } }
.also { println("Precomputed join map has ${it.size} entries.") }.toMutableMap()

fun join(l: Set<Σᐩ>, r: Set<Σᐩ>, tryCache: Boolean = false): Set<Π3A<Σᐩ>> =
fun join(l: Set<Σᐩ>, r: Set<Σᐩ>, tryCache: 𝔹 = false): Set<Π3A<Σᐩ>> =
if (tryCache) precomputedJoins[l to r] ?: join(l, r, false).also { precomputedJoins[l to r] = it }
else (l * r).flatMap { (l, r) -> CFG.bimap[listOf(l, r)].map { Triple(it, l, r) } }.toSet()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
@file:Suppress("NonAsciiCharacters")

package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.sampling.*
import ai.hypergraph.kaliningraph.splitProd
import ai.hypergraph.kaliningraph.tensor.*
import ai.hypergraph.kaliningraph.types.*


// SetValiant interface
//=====================================================================================
fun Σᐩ.matches(cfg: Σᐩ): Boolean = matches(cfg.validate().parseCFG())
fun Σᐩ.matches(CFG: CFG): Boolean = CFG.isValid(tokenizeByWhitespace())
fun Σᐩ.matches(CJL: CJL): Boolean = CJL.cfgs.all { matches(it) }
fun List<Σᐩ>.matches(CFG: CFG): Boolean = CFG.isValid(this)
fun Σᐩ.matches(cfg: Σᐩ): 𝔹 = matches(cfg.validate().parseCFG())
fun Σᐩ.matches(CFG: CFG): 𝔹 = CFG.isValid(tokenizeByWhitespace())
fun Σᐩ.matches(CJL: CJL): 𝔹 = CJL.cfgs.all { matches(it) }
fun List<Σᐩ>.matches(CFG: CFG): 𝔹 = CFG.isValid(this)
fun Σᐩ.parse(s: Σᐩ): Tree? = parseCFG().parse(s)
fun CFG.parse(s: Σᐩ): Tree? =
try { parseForest(s).firstOrNull { it.root == START_SYMBOL }?.denormalize() }
Expand All @@ -31,8 +34,8 @@ private fun List<Σᐩ>.pad3(): List<Σᐩ> =
else if (size == 1) listOf("ε", first(), "ε")
else this

fun CFG.isValid(str: Σᐩ): Boolean = isValid(str.tokenizeByWhitespace())
fun CFG.isValid(str: List<Σᐩ>): Boolean =
fun CFG.isValid(str: Σᐩ): 𝔹 = isValid(str.tokenizeByWhitespace())
fun CFG.isValid(str: List<Σᐩ>): 𝔹 =
initialUTBMatrix(str.pad3()).seekFixpoint().diagonals
// .also { it.forEachIndexed { r, d -> d.forEachIndexed { i, it -> println("$r, $i: ${toNTSet(it)}") } } }
.last().first()//.also { println("Last: ${it.joinToString(",") {if (it) "1" else "0"}}") }
Expand Down Expand Up @@ -100,12 +103,12 @@ fun CFG.treeJoin(left: Forest, right: Forest): Forest =
fun CFG.setJoin(left: Set<Σᐩ>, right: Set<Σᐩ>): Set<Σᐩ> =
(left * right).flatMap { bimap[it.toList()] }.toSet()

fun CFG.toBitVec(nts: Set<Σᐩ>): BooleanArray =
fun CFG.toBitVec(nts: Set<Σᐩ>): 𝔹ⁿ =
if (1 < nts.size) nonterminals.map { it in nts }.toBooleanArray()
else BooleanArray(nonterminals.size) { false }
.also { if (1 == nts.size) it[bindex[nts.first()]] = true }

fun fastJoin(/**[vindex]*/vidx: Array<IntArray>, left: BooleanArray, right: BooleanArray): BooleanArray {
fun fastJoin(/**[vindex]*/vidx: Array<IntArray>, left: 𝔹ⁿ, right: 𝔹ⁿ): 𝔹ⁿ {
if (left.isEmpty() || right.isEmpty()) return booleanArrayOf()

val result = BooleanArray(vidx.size)
Expand All @@ -125,18 +128,18 @@ fun fastJoin(/**[vindex]*/vidx: Array<IntArray>, left: BooleanArray, right: Bool
// if (left.isEmpty() || right.isEmpty()) booleanArrayOf()
// else vindex.map { it.any { (B, C) -> left[B] and right[C] } }.toBooleanArray()

fun CFG.join(left: BooleanArray, right: BooleanArray): BooleanArray = fastJoin(vindex, left, right)
fun CFG.join(left: 𝔹ⁿ, right: 𝔹ⁿ): 𝔹ⁿ = fastJoin(vindex, left, right)

fun maybeJoin(vindexFast: Array<IntArray>, left: BooleanArray?, right: BooleanArray?): BooleanArray? =
fun maybeJoin(vindexFast: Array<IntArray>, left: 𝔹ⁿ?, right: 𝔹ⁿ?): 𝔹ⁿ? =
if (left == null || right == null) null else fastJoin(vindexFast, left, right)

fun maybeUnion(left: BooleanArray?, right: BooleanArray?): BooleanArray? =
fun maybeUnion(left: 𝔹ⁿ?, right: 𝔹ⁿ?): 𝔹ⁿ? =
if (left == null || right == null) { left ?: right }
else if (left.isEmpty() && right.isNotEmpty()) right
else if (left.isNotEmpty() && right.isEmpty()) left
else union(left, right)

fun union(left: BooleanArray, right: BooleanArray): BooleanArray {
fun union(left: 𝔹ⁿ, right: 𝔹ⁿ): 𝔹ⁿ {
val result = BooleanArray(left.size)
for (i in left.indices) {
result[i] = left[i]
Expand All @@ -146,7 +149,7 @@ fun union(left: BooleanArray, right: BooleanArray): BooleanArray {
return result
}

val CFG.bitwiseAlgebra: Ring<BooleanArray> by cache {
val CFG.bitwiseAlgebra: Ring<𝔹ⁿ> by cache {
vindex.let {
Ring.of(
nil = BooleanArray(nonterminals.size) { false },
Expand All @@ -157,7 +160,7 @@ val CFG.bitwiseAlgebra: Ring<BooleanArray> by cache {
}

// Like bitwiseAlgebra, but with nullable bitvector literals for free variables
val CFG.satLitAlgebra: Ring<BooleanArray?> by cache {
val CFG.satLitAlgebra: Ring<𝔹ⁿ?> by cache {
vindex.let {
Ring.of(
nil = BooleanArray(nonterminals.size) { false },
Expand All @@ -167,26 +170,26 @@ val CFG.satLitAlgebra: Ring<BooleanArray?> by cache {
}
}

fun CFG.toNTSet(nts: BooleanArray): Set<Σᐩ> =
fun CFG.toNTSet(nts: 𝔹ⁿ): Set<Σᐩ> =
nts.mapIndexed { i, it -> if (it) bindex[i] else null }.filterNotNull().toSet()

fun BooleanArray.decodeWith(cfg: CFG): Set<Σᐩ> =
fun 𝔹ⁿ.decodeWith(cfg: CFG): Set<Σᐩ> =
mapIndexed { i, it -> if (it) cfg.bindex[i] else null }.filterNotNull().toSet()

fun CFG.toBooleanArray(nts: Set<Σᐩ>): BooleanArray =
fun CFG.toBooleanArray(nts: Set<Σᐩ>): 𝔹ⁿ =
BooleanArray(nonterminals.size) { i -> bindex[i] in nts }

//=====================================================================================

val HOLE_MARKER = "_"
fun Σᐩ.containsHole(): Boolean = HOLE_MARKER in this
fun Σᐩ.containsHole(): 𝔹 = HOLE_MARKER in this
fun Σᐩ.isHoleTokenIn(cfg: CFG) = containsHole() || isNonterminalStubIn(cfg)
//val ntRegex = Regex("<[^\\s>]*>")
fun Σᐩ.isNonterminalStub() = isNotEmpty() && first() == '<' && last() == '>'
fun Σᐩ.isNonterminalStubInNTs(nts: Set<Σᐩ>): Boolean = isNonterminalStub() && drop(1).dropLast(1) in nts
fun Σᐩ.isNonterminalStubIn(cfg: CFG): Boolean = isNonterminalStub() && drop(1).dropLast(1) in cfg.nonterminals
fun Σᐩ.isNonterminalStubIn(CJL: CJL): Boolean = CJL.cfgs.map { isNonterminalStubIn(it) }.all { it }
fun String.containsNonterminal(): Boolean = Regex("<[^\\s>]*>") in this
fun Σᐩ.isNonterminalStubInNTs(nts: Set<Σᐩ>): 𝔹 = isNonterminalStub() && drop(1).dropLast(1) in nts
fun Σᐩ.isNonterminalStubIn(cfg: CFG): 𝔹 = isNonterminalStub() && drop(1).dropLast(1) in cfg.nonterminals
fun Σᐩ.isNonterminalStubIn(CJL: CJL): 𝔹 = CJL.cfgs.map { isNonterminalStubIn(it) }.all { it }
fun Σᐩ.containsNonterminal(): 𝔹 = Regex("<[^\\s>]*>") in this

// Converts tokens to UT matrix via constructor: σ_i = { A | (A -> w[i]) ∈ P }
fun CFG.initialMatrix(str: List<Σᐩ>): TreeMatrix =
Expand All @@ -201,11 +204,13 @@ fun CFG.initialUTBMatrix(
tokens: List<Σᐩ>,
allNTs: Set<Σᐩ> = nonterminals,
bmp: BiMap = bimap,
unitReach: Map<Σᐩ, Set<String>> = originalForm.unitReachability
): UTMatrix<BooleanArray> =
unitReach: Map<Σᐩ, Set<Σᐩ>> = originalForm.unitReachability
): UTMatrix<𝔹ⁿ> =
UTMatrix(
ts = tokens.map { it ->
bmp[listOf(it)].let { nts ->
// Check whether the token is part of a string that contains a user-
// defined nonterminal stub that was in the original grammar
if (tokens.none { it.isNonterminalStubInNTs(allNTs) }) nts
// We use the original form because A -> B -> C can be normalized
// to A -> C, and we want B to be included in the equivalence class
Expand Down Expand Up @@ -244,8 +249,8 @@ private val freshNames: Sequence<Σᐩ> =
.filter { it != START_SYMBOL }

fun Σᐩ.parseCFG(
normalize: Boolean = true,
validate: Boolean = false
normalize: 𝔹 = true,
validate: 𝔹 = false
): CFG =
(if (validate) validate() else this).lines().filter { "->" in it }.map { line ->
val prod = line.splitProd()
Expand Down Expand Up @@ -281,7 +286,7 @@ fun Σᐩ.validate(
fun List<Σᐩ>.solve(
CFG: CFG,
fillers: Set<Σᐩ> = CFG.terminals - CFG.blocked,
takeMoreWhile: () -> Boolean = { true },
takeMoreWhile: () -> 𝔹 = { true },
): Sequence<Σᐩ> =
genCandidates(CFG, fillers)
// .also { println("Solving (Complexity: ${fillers.size.pow(count { it == "_" })}): ${joinToString(" ")}") }
Expand All @@ -295,7 +300,7 @@ fun List<Σᐩ>.genCandidates(CFG: CFG, fillers: Set<Σᐩ> = CFG.terminals): Se
}

// TODO: Compactify [en/de]coding: https://news.ycombinator.com/item?id=31442706#31442719
fun CFG.nonterminals(bitvec: List<Boolean>): Set<Σᐩ> =
fun CFG.nonterminals(bitvec: List<𝔹>): Set<Σᐩ> =
bitvec.mapIndexedNotNull { i, it -> if (it) bindex[i] else null }.toSet()
.apply { ifEmpty { throw Exception("Unable to reconstruct NTs from: $bitvec") } }

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
@file:Suppress("NonAsciiCharacters")

package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*

// Tokenizes [s] by whitespace, solves the sorted fixpoint and returns the
// distinct substrings recorded under START_SYMBOL. Best-effort: a missing
// START_SYMBOL entry, or any exception raised while solving, yields an empty set.
fun CFG.sortAll(s: Σᐩ): Set<Σᐩ> {
  return try {
    val rootCell = solveSortedFP(s.tokenizeByWhitespace())
    val records = rootCell[START_SYMBOL] ?: return setOf()
    records.map { it.first }.toSet()
  } catch (e: Exception) {
    // Deliberately swallowed: callers only ever see an empty result on failure.
    setOf()
  }
}

// Iterates [utMatrix] (by default the initial matrix built from [tokens]) to a
// fixpoint, expands it to a full matrix, and returns the last element of the
// entry at index 0 (presumably the top-right cell spanning the whole input —
// TODO confirm against UTMatrix/FreeMatrix).
fun CFG.solveSortedFP(
  tokens: List<Σᐩ>,
  utMatrix: UTMatrix<Sort> = initialUTSMatrix(tokens),
) = utMatrix.seekFixpoint().toFullMatrix().let { full -> full[0].last() }

// Builds the initial matrix over Sort values, one cell per input token.
//  - For a HOLE_MARKER token the cell keys are every terminal that appears as a
//    single-symbol right-hand side of some symbol found in unitReachability's values.
//  - For any other token the cell keys are whatever bmp maps listOf(token) to.
// Every key is associated with a single record (token, metric), where the
// metric is 0 for "ε" and 1 otherwise.
fun CFG.initialUTSMatrix(tokens: List<Σᐩ>, bmp: BiMap = bimap): UTMatrix<Sort> {
  val cells: Array<Sort> = tokens.map { token ->
    val keys =
      if (token == HOLE_MARKER)
        unitReachability.values.flatten().toSet().flatMap { root ->
          bmp[root]
            .filter { it.size == 1 }
            .map { it.first() }
            .filter { it in terminals }
        }.toSet()
      else bmp[listOf(token)]

    keys.associateWith { listOf(token to if (token == "ε") 0 else 1) }
  }.toTypedArray()

  // Debug output: prints the number of keys in each initial cell.
  cells.forEach { println(it.size) }

  return UTMatrix(ts = cells, algebra = sortwiseAlgebra)
}

// Ring over Sort values used by the sorted solver: the additive identity is
// the empty map, addition is [union] and multiplication is the CFG-aware [join].
val CFG.sortwiseAlgebra: Ring<Sort> by cache {
  Ring.of(
    nil = emptyMap(),
    plus = { lhs, rhs -> union(lhs, rhs) },
    times = { lhs, rhs -> join(lhs, rhs) },
  )
}

// Combines two records: the substrings are concatenated and the metrics summed.
// NOTE(review): the substrings are joined without a separator, so adjacent
// tokens run together — confirm whether a space is expected by callers.
operator fun SRec.plus(s2: SRec): SRec {
  val (text, metric) = this
  val (otherText, otherMetric) = s2
  return (text + otherText) to (metric + otherMetric)
}

// Multiplies two Sort values under this grammar. For every entry (lhs, rhss)
// of bimap.L2RHS and every two-symbol right-hand side (b, c) among rhss, each
// record stored under b in [s1] is combined (via SRec.plus) with each record
// stored under c in [s2]. Every lhs key is present in the result; it maps to
// an empty list when no binary expansion has records on both sides.
fun CFG.join(s1: Sort, s2: Sort): Sort =
  bimap.L2RHS.entries.associate { (lhs, rhss) ->
    lhs to rhss.filter { it.size == 2 }.flatMap { (b, c) ->
      val left = s1[b]
      val right = s2[c]
      // Both operands must be present: a production B C can only be applied
      // when B has records on the left AND C has records on the right.
      if (left == null || right == null) emptyList<SRec>()
      else (left.toSet() * right.toSet()).map { (q, r) -> q + r }
    }
  }

// Adds two Sort values key by key.
//  - A key present on only one side keeps that side's list unchanged, so the
//    empty map acts as the identity: union(emptyMap(), x) == x.
//  - A key present on both sides gets the distinct records of both lists,
//    ordered by ascending metric (the second component). The sort is stable,
//    so among equal metrics records from [s1] precede those from [s2].
fun union(s1: Sort, s2: Sort): Sort =
  (s1.keys + s2.keys).associateWith { key ->
    val left = s1[key]
    val right = s2[key]
    when {
      left == null -> right.orEmpty()
      right == null -> left
      else -> (left + right).distinct().sortedBy { it.second }
    }
  }

// Maps a grammar symbol to the list of records (see SRec) collected for it.
typealias Sort = Map<Σᐩ, List<SRec>>
// A substring paired with an integer metric (e.g., number of blanks).
// The metric is currently seeded with 0 for "ε" tokens and 1 for every other
// token, and is summed when two records are concatenated.
// TODO: Associate a more concrete semantics with the second value.
// For example, we could use the perplexity of a Markov chain or the length
// of the longest common substring with the original string.
typealias SRec = Π2<Σᐩ, Int>
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
@file:Suppress("UNUSED_PARAMETER", "UNCHECKED_CAST")
@file:Suppress("UNUSED_PARAMETER", "UNCHECKED_CAST", "NonAsciiCharacters")
package ai.hypergraph.kaliningraph.types

import kotlin.jvm.JvmName

typealias 𝔹 = Boolean
typealias 𝔹ⁿ = BooleanArray

sealed class B<X, P : B<X, P>>(open val x: X? = null) {
val T: T<P> get() = T(this as P)
val F: F<P> get() = F(this as P)
Expand Down
Loading

0 comments on commit f2fbab4

Please sign in to comment.