From ac7f61ed8547d7784b6729506af7e478ef56ae17 Mon Sep 17 00:00:00 2001 From: breandan Date: Sat, 24 Feb 2024 16:55:02 -0500 Subject: [PATCH] speed up PCFG sampler --- .../kaliningraph/parsing/SeqValiant.kt | 42 ++++++++++++------- .../kaliningraph/parsing/JVMBarHillel.kt | 6 +-- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt index 361d43e2..74a48d5f 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt @@ -7,10 +7,10 @@ import ai.hypergraph.kaliningraph.tensor.UTMatrix import ai.hypergraph.kaliningraph.types.* import com.ionspin.kotlin.bignum.integer.* import kotlin.jvm.JvmName -import kotlin.math.* import kotlin.random.Random import kotlin.time.measureTimedValue + // Indexes a set of PTrees by their roots typealias PForest = Map // ℙ₃ // Algebraic data type / polynomial functor for parse forests (ℙ₂) @@ -59,7 +59,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() } private val choice by lazy { - if (branches.isEmpty()) listOf(if ("ε" in root) "" else root) + if (branches.isEmpty()) listOf(epsStr) else shuffledBranches.flatMap { (l, r) -> (l.choose() * r.choose()).map { (a, b) -> if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b" @@ -69,7 +69,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() val parikhBounds: ParikhBounds by lazy { if (branches.isEmpty()) { - if ("ε" in root) mapOf() else mapOf(root to 1..1) + if (epsStr.isEmpty()) mapOf() else mapOf(root to 1..1) } else branches.map { it.first.parikhBounds * it.second.parikhBounds } .reduce(ParikhBounds::plus) } @@ -78,7 +78,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() // Average time: 436.96ms, total time 43696.959ms (testRandomCFG) private fun decodeString(i: BigInteger): Pair { - if (branches.isEmpty()) return (if ("ε" in root) "" else root) to i + if (branches.isEmpty()) return (epsStr) to i val (quotient1, remainder) = i.divrem(branches.size.toBigInteger()) val (lb, rb) = shuffledBranches[remainder.intValue()] val (l, quotient2) = lb.decodeString(quotient1) @@ -89,7 +89,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() // Average time: 328.99ms, total time 32899.708ms (testRandomCFG) private fun decodeStringFast(i: Long): Pair { - if (branches.isEmpty()) return (if ("ε" in root) "" else root) to i + if (branches.isEmpty()) return (epsStr) to i val (quotient1, remainder) = i / branches.size.toLong() to (i % branches.size.toLong()) val (lb, rb) = shuffledBranches[remainder.toInt()] val (l, quotient2) = lb.decodeStringFast(quotient1) @@ -120,7 +120,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() while (i < 9 * totalTrees) yield(decodeString(i++ * stride + offset).first) } - fun sampleStrWithPCFG5(pcfgTable: Map<Π5A<Σᐩ>, Int>): Sequence = + fun sampleStrWithPCFG5(pcfgTable: Map): Sequence = sequence { while (true) yield(samplePCFG5(pcfgTable)) } fun sampleStrWithPCFG3(pcfgTable: Map<Π3A<Σᐩ>, Int>): Sequence = @@ -142,8 +142,11 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() Tree(root, children = arrayOf(a, b)) } + val epsStr by lazy { if ("ε" in root) "" else root } + val dotEpsStr by lazy { if (".ε" in root) "" else root } + fun sample(): String = - if (branches.isEmpty()) if ("ε" in root) "" else root + if (branches.isEmpty()) epsStr else branches.random().let { (l, r) -> val (a, b) = l.sample() to r.sample() if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b" @@ -151,23 +154,24 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() fun Σᐩ.name() = if ("~" in this) split("~")[1] else this val triples : List<Π3A<Σᐩ>> by lazy { branches.map { root.name() to it.first.root.name() to it.second.root.name() } } + val rootName by lazy { root.name() } + val isLeaf by lazy { branches.isEmpty() } - fun samplePCFG5(pcfgTable: Map<Π5A<Σᐩ>, Int>, upUp: Σᐩ = "NIL", upLeft: Σᐩ = "NIL", upRight: Σᐩ = "NIL"): Σᐩ { - if (branches.isEmpty()) return if ("ε" in root) "" else root - val rt = root.name() - val probs = triples.map { (pcfgTable[upUp to upLeft to upRight to it.second to it.third] ?: 1) + 1 } + fun samplePCFG5(pcfgTable: Map, upUp: Σᐩ = "NIL", upLeft: Σᐩ = "NIL", upRight: Σᐩ = "NIL"): Σᐩ { + if (isLeaf) return epsStr + val probs = triples.map { (pcfgTable[StrQuintuple(upUp, upLeft, upRight, it.second, it.third)] ?: 1) + 1 } val cdf = probs.runningReduce { acc, i -> acc + i } val rnd = Random.nextInt(probs.sum()) val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it } val (l, r) = branches[childIdx] - val (lr, rr) = l.root.name() to r.root.name() - val (a, b) = l.samplePCFG5(pcfgTable, rt, "$lr*", rr) to - r.samplePCFG5(pcfgTable, rt, lr, "$rr*") + val (lr, rr) = l.rootName to r.rootName + val (a, b) = l.samplePCFG5(pcfgTable, rootName, "$lr*", rr) to + r.samplePCFG5(pcfgTable, rootName, lr, "$rr*") return if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b" } fun samplePCFG3(pcfgTable: Map<Π3A<Σᐩ>, Int>): Σᐩ { - if (branches.isEmpty()) return if ("ε" in root) "" else root + if (branches.isEmpty()) return epsStr val probs = triples.map { (pcfgTable[it] ?: 1) + 1 } val cdf = probs.runningReduce { acc, i -> acc + i } @@ -182,7 +186,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() // Prefers shorter strings, i.e., strings with more ε tokens fun sampleStrWithGeomDecay(): String = - if (branches.isEmpty()) if (".ε" in root) "" else root + if (branches.isEmpty()) dotEpsStr else { // val p = 0.9 // Adjust this for different decay rates // val rnd = Random.nextDouble() @@ -202,6 +206,12 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() // } } +data class StrQuintuple(val a: String, val b: String, val c: String, val d: String, val e: String) { + val hash = a.hashCode() + b.hashCode() + c.hashCode() + d.hashCode() + e.hashCode() + override fun hashCode(): Int = hash + override fun equals(other: Any?): Boolean = other is StrQuintuple && other.hash == hash +} + fun CFG.startPTree(tokens: List) = //measureTimedValue { initPForestMat(tokens).seekFixpoint().diagonals.last()[0][START_SYMBOL] //}.also { println("Took ${it.duration} to compute parse forest") }.value diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt index 5f06e4e5..b42c67d1 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt @@ -9,7 +9,6 @@ import ai.hypergraph.kaliningraph.types.times import java.util.concurrent.* import java.util.concurrent.atomic.AtomicInteger import java.util.stream.* -import kotlin.math.max import kotlin.streams.* import kotlin.time.Duration.Companion.minutes import kotlin.time.TimeSource @@ -77,7 +76,7 @@ fun CFG.sampleDirectlyWR( } fun CFG.sampleWithPCFG( - pcfgTable: Map<Π5A<Σᐩ>, Int>, + pcfgTable: Map, cores: Int = NUM_CORES, stoppingCriterion: () -> Boolean = { true } ): Stream = @@ -204,8 +203,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG { println("Constructed ∩-grammar with $totalProds productions in ${clock.elapsedNow()}") clock = TimeSource.Monotonic.markNow() - return Stream.concat(binaryProds.stream(), - (initFinal + transits + unitProds).stream()).parallel() + return Stream.concat(binaryProds.stream(), (initFinal + transits + unitProds).stream()).parallel() .filter { (_, rhs) -> rhs.all { !it.isSyntheticNT() || it in nts } } .collect(Collectors.toSet()) .jvmPostProcess(clock)