Skip to content

Commit

Permalink
speed up PCFG sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Feb 24, 2024
1 parent 85f65f2 commit ac7f61e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*
import com.ionspin.kotlin.bignum.integer.*
import kotlin.jvm.JvmName
import kotlin.math.*
import kotlin.random.Random
import kotlin.time.measureTimedValue


// Indexes a set of PTrees by their roots
typealias PForest = Map<String, PTree> // ℙ₃
// Algebraic data type / polynomial functor for parse forests (ℙ₂)
Expand Down Expand Up @@ -59,7 +59,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
}

private val choice by lazy {
if (branches.isEmpty()) listOf(if ("ε" in root) "" else root)
if (branches.isEmpty()) listOf(epsStr)
else shuffledBranches.flatMap { (l, r) ->
(l.choose() * r.choose()).map { (a, b) ->
if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
Expand All @@ -69,7 +69,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()

val parikhBounds: ParikhBounds by lazy {
if (branches.isEmpty()) {
if ("ε" in root) mapOf() else mapOf(root to 1..1)
if (epsStr.isEmpty()) mapOf() else mapOf(root to 1..1)
} else branches.map { it.first.parikhBounds * it.second.parikhBounds }
.reduce(ParikhBounds::plus)
}
Expand All @@ -78,7 +78,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()

// Average time: 436.96ms, total time 43696.959ms (testRandomCFG)
private fun decodeString(i: BigInteger): Pair<String, BigInteger> {
if (branches.isEmpty()) return (if ("ε" in root) "" else root) to i
if (branches.isEmpty()) return (epsStr) to i
val (quotient1, remainder) = i.divrem(branches.size.toBigInteger())
val (lb, rb) = shuffledBranches[remainder.intValue()]
val (l, quotient2) = lb.decodeString(quotient1)
Expand All @@ -89,7 +89,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()

// Average time: 328.99ms, total time 32899.708ms (testRandomCFG)
private fun decodeStringFast(i: Long): Pair<String, Long> {
if (branches.isEmpty()) return (if ("ε" in root) "" else root) to i
if (branches.isEmpty()) return (epsStr) to i
val (quotient1, remainder) = i / branches.size.toLong() to (i % branches.size.toLong())
val (lb, rb) = shuffledBranches[remainder.toInt()]
val (l, quotient2) = lb.decodeStringFast(quotient1)
Expand Down Expand Up @@ -120,7 +120,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
while (i < 9 * totalTrees) yield(decodeString(i++ * stride + offset).first)
}

fun sampleStrWithPCFG5(pcfgTable: Map<Π5A<Σᐩ>, Int>): Sequence<String> =
fun sampleStrWithPCFG5(pcfgTable: Map<StrQuintuple, Int>): Sequence<String> =
sequence { while (true) yield(samplePCFG5(pcfgTable)) }

fun sampleStrWithPCFG3(pcfgTable: Map<Π3A<Σᐩ>, Int>): Sequence<String> =
Expand All @@ -142,32 +142,36 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
Tree(root, children = arrayOf(a, b))
}

val epsStr by lazy { if ("ε" in root) "" else root }
val dotEpsStr by lazy { if ("" in root) "" else root }

fun sample(): String =
if (branches.isEmpty()) if ("ε" in root) "" else root
if (branches.isEmpty()) epsStr
else branches.random().let { (l, r) ->
val (a, b) = l.sample() to r.sample()
if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
}

fun Σᐩ.name() = if ("~" in this) split("~")[1] else this
val triples : List<Π3A<Σᐩ>> by lazy { branches.map { root.name() to it.first.root.name() to it.second.root.name() } }
val rootName by lazy { root.name() }
val isLeaf by lazy { branches.isEmpty() }

fun samplePCFG5(pcfgTable: Map<Π5A<Σᐩ>, Int>, upUp: Σᐩ = "NIL", upLeft: Σᐩ = "NIL", upRight: Σᐩ = "NIL"): Σᐩ {
if (branches.isEmpty()) return if ("ε" in root) "" else root
val rt = root.name()
val probs = triples.map { (pcfgTable[upUp to upLeft to upRight to it.second to it.third] ?: 1) + 1 }
fun samplePCFG5(pcfgTable: Map<StrQuintuple, Int>, upUp: Σᐩ = "NIL", upLeft: Σᐩ = "NIL", upRight: Σᐩ = "NIL"): Σᐩ {
if (isLeaf) return epsStr
val probs = triples.map { (pcfgTable[StrQuintuple(upUp, upLeft, upRight, it.second, it.third)] ?: 1) + 1 }
val cdf = probs.runningReduce { acc, i -> acc + i }
val rnd = Random.nextInt(probs.sum())
val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it }
val (l, r) = branches[childIdx]
val (lr, rr) = l.root.name() to r.root.name()
val (a, b) = l.samplePCFG5(pcfgTable, rt, "$lr*", rr) to
r.samplePCFG5(pcfgTable, rt, lr, "$rr*")
val (lr, rr) = l.rootName to r.rootName
val (a, b) = l.samplePCFG5(pcfgTable, rootName, "$lr*", rr) to
r.samplePCFG5(pcfgTable, rootName, lr, "$rr*")
return if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
}

fun samplePCFG3(pcfgTable: Map<Π3A<Σᐩ>, Int>): Σᐩ {
if (branches.isEmpty()) return if ("ε" in root) "" else root
if (branches.isEmpty()) return epsStr

val probs = triples.map { (pcfgTable[it] ?: 1) + 1 }
val cdf = probs.runningReduce { acc, i -> acc + i }
Expand All @@ -182,7 +186,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()

// Prefers shorter strings, i.e., strings with more ε tokens
fun sampleStrWithGeomDecay(): String =
if (branches.isEmpty()) if ("" in root) "" else root
if (branches.isEmpty()) dotEpsStr
else {
// val p = 0.9 // Adjust this for different decay rates
// val rnd = Random.nextDouble()
Expand All @@ -202,6 +206,12 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
// }
}

data class StrQuintuple(val a: String, val b: String, val c: String, val d: String, val e: String) {
val hash = a.hashCode() + b.hashCode() + c.hashCode() + d.hashCode() + e.hashCode()
override fun hashCode(): Int = hash
override fun equals(other: Any?): Boolean = other is StrQuintuple && other.hash == hash
}

fun CFG.startPTree(tokens: List<String>) = //measureTimedValue {
initPForestMat(tokens).seekFixpoint().diagonals.last()[0][START_SYMBOL]
//}.also { println("Took ${it.duration} to compute parse forest") }.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import ai.hypergraph.kaliningraph.types.times
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger
import java.util.stream.*
import kotlin.math.max
import kotlin.streams.*
import kotlin.time.Duration.Companion.minutes
import kotlin.time.TimeSource
Expand Down Expand Up @@ -77,7 +76,7 @@ fun CFG.sampleDirectlyWR(
}

fun CFG.sampleWithPCFG(
pcfgTable: Map<Π5A<Σᐩ>, Int>,
pcfgTable: Map<StrQuintuple, Int>,
cores: Int = NUM_CORES,
stoppingCriterion: () -> Boolean = { true }
): Stream<String> =
Expand Down Expand Up @@ -204,8 +203,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {
println("Constructed ∩-grammar with $totalProds productions in ${clock.elapsedNow()}")

clock = TimeSource.Monotonic.markNow()
return Stream.concat(binaryProds.stream(),
(initFinal + transits + unitProds).stream()).parallel()
return Stream.concat(binaryProds.stream(), (initFinal + transits + unitProds).stream()).parallel()
.filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } }
.collect(Collectors.toSet())
.jvmPostProcess(clock)
Expand Down

0 comments on commit ac7f61e

Please sign in to comment.