Skip to content

Commit

Permalink
speed up filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Feb 25, 2024
1 parent 2575a0f commit 5863d0d
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ fun Array<DoubleArray>.toDoubleMatrix() = DoubleMatrix(size, this[0].size) { i,

/** Kronecker delta: 1.0 when the two indices coincide, 0.0 otherwise. */
fun kroneckerDelta(i: Int, j: Int) = when (i) {
  j -> 1.0
  else -> 0.0
}

/**
 * Collision-free pairing of two non-negative indices into a single Int key.
 *
 * The previous implementation (i1 * 31 + i2) collided for distinct pairs —
 * e.g. (0, 31) and (1, 0) both hashed to 31 — which corrupts exact-key
 * lookups such as the APSP table keyed by state-index pairs: one pair's
 * distance silently overwrites or answers for another's. Szudzik's pairing
 * function is bijective on non-negative inputs, so no two distinct pairs
 * share a key. It is overflow-free while both indices stay below 46_341
 * (~sqrt(Int.MAX_VALUE)), which comfortably covers FSA state counts.
 */
fun hashPair(i1: Int, i2: Int): Int =
  if (i1 >= i2) i1 * i1 + i1 + i2 else i2 * i2 + i1

// Default dimensionality of the feature vectors produced by [vectorize].
const val DEFAULT_FEATURE_LEN = 20
// Deterministic string embedding: seeds a PRNG with the string's hashCode so
// the same string always maps to the same vector of the requested length.
// NOTE(review): if randomVector's trailing lambda declares its own parameter,
// the inner `it` shadows the `let` receiver and `it.nextDouble()` would not
// draw from the seeded Random — verify against randomVector's signature.
fun String.vectorize(len: Int = DEFAULT_FEATURE_LEN) =
Random(hashCode()).let { randomVector(len) { it.nextDouble() } }
Expand Down
13 changes: 7 additions & 6 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt
Original file line number Diff line number Diff line change
@@ -1,35 +1,36 @@
package ai.hypergraph.kaliningraph.automata

import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.graphs.*
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.types.*
import kotlin.math.*

typealias Arc = Π3A<Σᐩ>
typealias TSA = Set<Arc>
/** Renders an arc as "source -<label>-> target" for debugging/printing. */
fun Arc.pretty(): Σᐩ = buildString {
  append(π1)
  append(" -<")
  append(π2)
  append(">-> ")
  append(π3)
}
// Parses the (i, j) grid coordinates encoded in a state label: skips a
// 2-character prefix, reads the digits of i up to the midpoint, skips one
// separator character, then reads the digits of j to the end.
// NOTE(review): assumes i and j occupy equal widths (length / 2 - 1 each) —
// confirm against the code that generates these state labels.
fun Σᐩ.coords(): Pair<Int, Int> =
(length / 2 - 1).let { substring(2, it + 2).toInt() to substring(it + 3).toInt() }
typealias STC = Triple<Σᐩ, Int, Int>
typealias STC = Triple<Int, Int, Int>
// Projects the (i, j) coordinate pair out of a state triple (π1 is the state id).
fun STC.coords() = π2 to π3

open class FSA(open val Q: TSA, open val init: Set<Σᐩ>, open val final: Set<Σᐩ>) {
open val alphabet by lazy { Q.map { it.π2 }.toSet() }
val isNominalizable by lazy { alphabet.any { it.startsWith("[!=]") } }
val nominalForm: NOM by lazy { nominalize() }
val states by lazy { Q.states }
val APSP: Map<Pair<Σᐩ, Σᐩ>, Int> by lazy {
val stateLst by lazy { states.toList() }
val stateMap by lazy { states.toList().withIndex().associate { it.value to it.index } }
val APSP: Map<Int, Int> by lazy {
graph.APSP.map { (k, v) ->
Pair(Pair(k.first.label, k.second.label), v)
Pair(hashPair(stateMap[k.first.label]!!, stateMap[k.second.label]!!), v)
}.toMap()
}

// Adjacency view of the transition set: maps each source state to the list of
// (label, target-state) pairs of its outgoing arcs; computed once on demand.
val transit: Map<Σᐩ, List<Pair<Σᐩ, Σᐩ>>> by lazy {
Q.groupBy { it.π1 }.mapValues { (_, v) -> v.map { it.π2 to it.π3 } }
}

val stateCoords: Sequence<STC> by lazy { states.map { it.coords().let { (i, j) -> Triple(it, i, j) } }.asSequence() }
val stateCoords: Sequence<STC> by lazy { states.map { it.coords().let { (i, j) -> Triple(stateMap[it]!!, i, j) } }.asSequence() }

// Cartesian cube of all state coordinates, kept only where isValidStateTriple()
// holds; materialized once so repeated intersections can reuse the list.
val validTriples by lazy { stateCoords.let { it * it * it }.filter { it.isValidStateTriple() }.toList() }

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.hashPair
import ai.hypergraph.kaliningraph.repair.MAX_TOKENS
import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.kaliningraph.types.times
import kotlin.math.absoluteValue
import kotlin.math.*
import kotlin.time.TimeSource

/**
Expand All @@ -28,7 +29,6 @@ fun CFG.barHillelRepair(prompt: List<Σᐩ>, distance: Int) =
// https://browse.arxiv.org/pdf/2209.06809.pdf#page=5
private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG {
var clock = TimeSource.Monotonic.markNow()
val lengthBoundsCache = lengthBounds
val nts = mutableSetOf("START")
fun Σᐩ.isSyntheticNT() =
first() == '[' && last() == ']' && count { it == '~' } == 2
Expand All @@ -50,11 +50,15 @@ private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG {

// For each production A → BC in P, for every p, q, r ∈ Q,
// we have the production [p,A,r] → [p,B,q] [q,C,r] in P′.
val validTriples =
fsa.stateCoords.let { it * it * it }.filter { it.isValidStateTriple() }.toList()
val ntLst = nonterminals.toList()
val ntMap = ntLst.withIndex().associate { (i, s) -> s to i }
val prods: Set<IProduction> = nonterminalProductions
.map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet()
val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it]!! } }
val validTriples: List<Triple<STC, STC, STC>> = fsa.validTriples

val binaryProds =
nonterminalProductions.map {
prods.map {
// if (i % 100 == 0) println("Finished ${i}/${nonterminalProductions.size} productions")
val (A, B, C) = it.π1 to it.π2[0] to it.π2[1]
validTriples
Expand All @@ -64,7 +68,7 @@ private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG {
.filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap) }
.map { (a, b, c) ->
val (p, q, r) = a.π1 to b.π1 to c.π1
"[$p~$A~$r]".also { nts.add(it) } to listOf("[$p~$B~$q]", "[$q~$C~$r]")
"[$p~${ntLst[A]}~$r]".also { nts.add(it) } to listOf("[$p~${ntLst[B]}~$q]", "[$q~${ntLst[C]}~$r]")
}.toList()
}.flatten().filterRHSInNTS()

Expand Down Expand Up @@ -244,48 +248,35 @@ fun Π3A<STC>.isValidStateTriple(): Boolean {
// && obeys(second, third, nts.third)
//}

fun Π3A<STC>.obeysLevenshteinParikhBounds(nts: Triple<Σᐩ, Σᐩ, Σᐩ>, fsa: FSA, parikhMap: ParikhMap): Boolean {
fun obeys(a: STC, b: STC, nt: Σᐩ): Bln {
val sl =
fsa.levString.size <= a.second || // Part of the LA that handles extra
fsa.levString.size <= b.second // terminals at the end of the string

if (sl) return true
val margin = (b.third - a.third).absoluteValue
val length = (b.second - a.second)
val range = (length - margin).coerceAtLeast(1)..(length + margin)
val pb = parikhMap.parikhBounds(nt, range)
val pv = fsa.parikhVector(a.second, b.second)
return pb.admits(pv, margin)
}
// Parikh admissibility for one chord (a -> b) of a state triple: decides
// whether the terminal counts observed between Levenshtein-automaton columns
// a.second and b.second fit the precomputed Parikh bounds of nonterminal [nt].
private fun FSA.obeys(a: STC, b: STC, nt: Int, parikhMap: ParikhMap): Bln {
// Columns past the end of the string are unconstrained (extra-terminal zone).
val sl = levString.size <= max(a.second, b.second) // Part of the LA that handles extra

// NOTE(review): the next three lines appear to be stray residue of a removed
// implementation (`nts` is not in scope in this signature), and this `return`
// makes everything after it unreachable — verify against the committed file.
return obeys(first, third, nts.first)
&& obeys(first, second, nts.second)
&& obeys(second, third, nts.third)
if (sl) return true
// margin: drift of the edit-distance coordinate between the two states.
val margin = (b.third - a.third).absoluteValue
val length = (b.second - a.second)
// Admissible span lengths, widened by the drift margin (never below 1).
val range = (length - margin).coerceAtLeast(1)..(length + margin)
val pb = parikhMap.parikhBounds(nt, range)
val pv = parikhVector(a.second, b.second)
return pb.admits(pv, margin)
}

fun Π3A<STC>.isCompatibleWith(nts: Triple<Σᐩ, Σᐩ, Σᐩ>, fsa: FSA, lengthBounds: Map<Σᐩ, IntRange>): Boolean {
fun lengthBounds(nt: Σᐩ): IntRange =
(lengthBounds[nt] ?: -9999..-9990)
// Okay if we overapproximate the length bounds a bit
// .let { (it.first - fudge)..(it.last + fudge) }

fun manhattanDistance(first: Pair<Int, Int>, second: Pair<Int, Int>): Int =
(second.second - first.second).absoluteValue + (second.first - first.first).absoluteValue
// A state triple obeys the Levenshtein-Parikh bounds iff every chord of the
// triple passes the per-nonterminal check. Guard clauses preserve the same
// left-to-right short-circuit order as the original && chain.
fun Π3A<STC>.obeysLevenshteinParikhBounds(nts: Triple<Int, Int, Int>, fsa: FSA, parikhMap: ParikhMap): Boolean {
  if (!fsa.obeys(first, third, nts.first, parikhMap)) return false
  if (!fsa.obeys(first, second, nts.second, parikhMap)) return false
  return fsa.obeys(second, third, nts.third, parikhMap)
}

// Range of the shortest path to the longest path, i.e., Manhattan distance
fun SPLP(a: STC, b: STC) =
(fsa.APSP[a.π1 to b.π1] ?: Int.MAX_VALUE)..
manhattanDistance(a.coords(), b.coords())
/** L1 (taxicab) distance between two integer grid points. */
private fun manhattanDistance(first: Pair<Int, Int>, second: Pair<Int, Int>): Int =
  abs(second.first - first.first) + abs(second.second - first.second)

fun IntRange.overlaps(other: IntRange) =
(other.first in first..last) || (other.last in first..last)
// Range of feasible path lengths between two states: from the all-pairs
// shortest path (Int.MAX_VALUE when the pair is absent, yielding an empty
// range) up to the Manhattan distance between their grid coordinates.
private fun FSA.SPLP(a: STC, b: STC): IntRange {
  val shortest = APSP[hashPair(a.π1, b.π1)] ?: Int.MAX_VALUE
  val longest = manhattanDistance(a.coords(), b.coords())
  return shortest..longest
}

// "[$p,$A,$r] -> [$p,$B,$q] [$q,$C,$r]"
fun isCompatible() =
lengthBounds(nts.first).overlaps(SPLP(first, third))
&& lengthBounds(nts.second).overlaps(SPLP(first, second))
&& lengthBounds(nts.third).overlaps(SPLP(second, third))
/**
 * True iff the two closed ranges share at least one point; empty ranges
 * overlap nothing.
 *
 * Fixes the previous endpoint-membership test, which returned false when
 * [other] strictly contained the receiver even though the ranges intersect —
 * unsoundly filtering out productions whose feasible path-length range
 * enclosed the nonterminal's length bounds.
 */
private fun IntRange.overlaps(other: IntRange): Boolean =
  !isEmpty() && !other.isEmpty() && first <= other.last && other.first <= last

return isCompatible()
}
// A production triple is compatible iff, for each chord of the state triple,
// the nonterminal's length bounds intersect the feasible path-length range.
// Guard clauses keep the original && chain's short-circuit order.
fun Π3A<STC>.isCompatibleWith(nts: Triple<Int, Int, Int>, fsa: FSA, lengthBounds: List<IntRange>): Boolean {
  if (!lengthBounds[nts.first].overlaps(fsa.SPLP(first, third))) return false
  if (!lengthBounds[nts.second].overlaps(fsa.SPLP(first, second))) return false
  return lengthBounds[nts.third].overlaps(fsa.SPLP(second, third))
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import kotlin.time.*

typealias Σᐩ = String
typealias Production = Π2<Σᐩ, List<Σᐩ>>
typealias IProduction = Π2<Int, List<Int>>
// TODO: make this immutable
typealias CFG = Set<Production>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ParikhMap(val cfg: CFG, val size: Int) {
private val lengthBounds: MutableMap<Int, Set<Σᐩ>> = mutableMapOf()
private val parikhMap: MutableMap<Int, ParikhBoundsMap> = mutableMapOf()
val parikhRangeMap: MutableMap<IntRange, ParikhBoundsMap> = mutableMapOf()
val ntIdx = cfg.nonterminals.toList()

companion object {
fun genRanges(delta: Int = 2 * MAX_RADIUS + 1, n: Int = MAX_TOKENS) =
Expand Down Expand Up @@ -83,6 +84,7 @@ class ParikhMap(val cfg: CFG, val size: Int) {
}
}

// Integer-indexed overload: resolves the nonterminal id through ntIdx and delegates.
fun parikhBounds(nt: Int, range: IntRange): ParikhBounds = parikhBounds(ntIdx[nt], range)
// Bounds for nonterminal [nt] over [range]; empty map when the range was never precomputed.
fun parikhBounds(nt: Σᐩ, range: IntRange): ParikhBounds = parikhRangeMap[range]?.get(nt) ?: emptyMap()
// Size-specific lookup; null when no bounds were recorded for [size].
fun parikhBounds(nt: Σᐩ, size: Int): ParikhBounds? = parikhMap[size]?.get(nt)
// parikhMap.also { println("Keys (${nt}): " + it.keys.size + ", ${it[size]?.get(nt)}") }[size]?.get(nt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {
// if (fsa.Q.size < 650) throw Exception("FSA size was out of bounds")
var clock = TimeSource.Monotonic.markNow()

val lengthBoundsCache = lengthBounds
val nts = ConcurrentSkipListSet(setOf("START"))

val initFinal =
Expand All @@ -170,31 +169,34 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {

// For each production A → BC in P, for every p, q, r ∈ Q,
// we have the production [p,A,r] → [p,B,q] [q,C,r] in P′.
val prods: Set<Production> = nonterminalProductions
var i = 0
val ntLst = nonterminals.toList()
val ntMap = ntLst.mapIndexed { i, s -> s to i }.toMap()
val prods: Set<IProduction> = nonterminalProductions
.map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet()
val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it]!! } }
val validTriples: List<Triple<STC, STC, STC>> = fsa.validTriples

val elimCounter = AtomicInteger(0)
val counter = AtomicInteger(0)
val lpClock = TimeSource.Monotonic.markNow()
val binaryProds =
prods.parallelStream().flatMap {
// if (i++ % 100 == 0) println("Finished $i/${nonterminalProductions.size} productions")
if (BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout: ${nts.size} nts")
val (A, B, C) = it.π1 to it.π2[0] to it.π2[1]
validTriples.stream()
// CFG ∩ FSA - in general we are not allowed to do this, but it works
// because we assume a Levenshtein FSA, which is monotone and acyclic.
.filter { it.isCompatibleWith(A to B to C, fsa, lengthBoundsCache).also { if (it) elimCounter.incrementAndGet() } }
.filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap).also { if (it) elimCounter.incrementAndGet() } }
.filter { it.isCompatibleWith(A to B to C, fsa, lengthBoundsCache).also { if (!it) elimCounter.incrementAndGet() } }
.filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap).also { if (!it) elimCounter.incrementAndGet() } }
.map { (a, b, c) ->
if (MAX_PRODS < counter.incrementAndGet())
throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)")
val (p, q, r) = a.π1 to b.π1 to c.π1
"[$p~$A~$r]".also { nts.add(it) } to listOf("[$p~$B~$q]", "[$q~$C~$r]")
val (p, q, r) = fsa.stateLst[a.π1] to fsa.stateLst[b.π1] to fsa.stateLst[c.π1]
"[$p~${ntLst[A]}~$r]".also { nts.add(it) } to listOf("[$p~${ntLst[B]}~$q]", "[$q~${ntLst[C]}~$r]")
}
}.toList()

println("LP constraints eliminated $elimCounter productions...")
println("Levenshtein-Parikh constraints eliminated $elimCounter productions in ${lpClock.elapsedNow()}")

fun Σᐩ.isSyntheticNT() =
first() == '[' && length > 1 // && last() == ']' && count { it == '~' } == 2
Expand All @@ -214,7 +216,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {
fun CFG.jvmPostProcess(clock: TimeSource.Monotonic.ValueTimeMark) =
jvmDropVestigialProductions(clock)
.jvmElimVarUnitProds()
.also { println("Reduced ∩-grammar from $size to ${it.size} useful productions in ${clock.elapsedNow()}") }
.also { println("Normalization eliminated ${size - it.size} productions in ${clock.elapsedNow()}") }
.freeze()

tailrec fun CFG.jvmElimVarUnitProds(
Expand Down

0 comments on commit 5863d0d

Please sign in to comment.