Skip to content

Commit

Permalink
implement PCFG scoring
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Feb 20, 2024
1 parent 400d8e1 commit d7321f3
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ class BiMap(cfg: CFG) {

/**
 * Builds the [PTree] rooted at [from] by recursively expanding every
 * right-hand side that `bimap[from]` yields.
 *
 * Each right-hand side contributes one branch: its first symbol is expanded
 * as the left subtree; a single-symbol right-hand side is paired with a
 * default-constructed `PTree()`, otherwise its second symbol is expanded as
 * the right subtree.
 *
 * N.b.: this only works if the CFG is acyclic (i.e., its language is finite);
 * otherwise the recursion never terminates.
 */
fun CFG.toPTree(from: Σᐩ = START_SYMBOL): PTree =
  PTree(
    from,
    bimap[from].map { rhs ->
      toPTree(rhs[0]) to if (rhs.size == 1) PTree() else toPTree(rhs[1])
    }
  )

/*
Γ ⊢ ∀ v.[α→*]∈G ⇒ α→[β] "If all productions rooted at α
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
}

/**
 * Returns an endless lazy sequence of strings, each obtained by one
 * independent call to [samplePCFG3] with the given [pcfgTable].
 *
 * The sequence never terminates on its own; callers must bound it
 * (e.g. with `take`).
 */
fun sampleStrWithPCFG(pcfgTable: Map<Π3A<Σᐩ>, Int>): Sequence<String> =
  sequence { while (true) yield(samplePCFG3(pcfgTable)) }

// Samples instantaneously from the parse forest, but may return duplicates
// and only returns a fraction of the number of distinct strings when compared
Expand Down Expand Up @@ -149,15 +149,28 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
fun Σᐩ.name() = if ("~" in this) split("~")[1] else this
val triples by lazy { branches.map { root.name() to it.first.root.name() to it.second.root.name() } }

fun samplePCFG(pcfgTable: Map<Π3A<Σᐩ>, Int>): Σᐩ {
/**
 * Samples one string from this tree, choosing among [branches] with weights
 * looked up in [pcfgTable] under a five-part key:
 * (parent label, [upLeft], [upRight], left child label, right child label).
 * A key missing from the table is treated as a count of 1, and every weight
 * is incremented by 1 before sampling.
 *
 * @param pcfgTable weights keyed by quintuple.
 * @param root label of this node's *parent* ("NIL" at the top). Note that it
 *   shadows the node's own `root` property inside this function.
 * @param upLeft label of the parent's left child ("NIL" at the top).
 * @param upRight label of the parent's right child ("NIL" at the top).
 * @return the sampled string; a leaf yields its own symbol, or "" when that
 *   symbol contains "ε". Non-empty halves are joined with a single space.
 */
fun samplePCFG5(pcfgTable: Map<Π5A<Σᐩ>, Int>, root: Σᐩ = "NIL", upLeft: Σᐩ = "NIL", upRight: Σᐩ = "NIL"): Σᐩ {
  // The `root` parameter hides the property of the same name, so this node's
  // own label must be read through `this`; otherwise leaves return the parent label.
  val ownLabel = this.root
  if (branches.isEmpty()) return if ("ε" in ownLabel) "" else ownLabel

  val parentName = root.name()
  val probs = triples.map { (pcfgTable[parentName to upLeft to upRight to it.second to it.third] ?: 1) + 1 }
  val cdf = probs.runningReduce { acc, i -> acc + i }
  val rnd = Random.nextInt(probs.sum())
  // Branch i owns the half-open interval [cdf[i-1], cdf[i]). An exact hit means
  // rnd == cdf[i], which is the first value of branch i+1, hence the `+ 1`.
  val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it + 1 }
  val (l, r) = branches[childIdx]
  val (lr, rr) = l.root.name() to r.root.name()
  // Children see this node as their parent, so pass this node's label down.
  val ownName = ownLabel.name()
  val (a, b) = l.samplePCFG5(pcfgTable, ownName, lr, rr) to r.samplePCFG5(pcfgTable, ownName, lr, rr)
  return if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
}

/**
 * Samples one string from this tree, choosing among [branches] with weights
 * looked up in [pcfgTable] by the corresponding entry of [triples]
 * (this node's name, left child name, right child name).
 * A triple missing from the table is treated as a count of 1, and every
 * weight is incremented by 1 before sampling.
 *
 * @param pcfgTable weights keyed by triple.
 * @return the sampled string; a leaf yields its own symbol, or "" when that
 *   symbol contains "ε". Non-empty halves are joined with a single space.
 */
fun samplePCFG3(pcfgTable: Map<Π3A<Σᐩ>, Int>): Σᐩ {
  if (branches.isEmpty()) return if ("ε" in root) "" else root

  val probs = triples.map { (pcfgTable[it] ?: 1) + 1 }
  val cdf = probs.runningReduce { acc, i -> acc + i }
  val rnd = Random.nextInt(probs.sum())
  // Branch i owns the half-open interval [cdf[i-1], cdf[i]). An exact hit means
  // rnd == cdf[i], which is the first value of branch i+1, hence the `+ 1`.
  val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it + 1 }
  val (l, r) = branches[childIdx]
  val (a, b) = l.samplePCFG3(pcfgTable) to r.samplePCFG3(pcfgTable)
  return if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
}

Expand Down
13 changes: 12 additions & 1 deletion src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Tree.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package ai.hypergraph.kaliningraph.parsing
import ai.hypergraph.kaliningraph.graphs.LGVertex
import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.tensor.FreeMatrix
import ai.hypergraph.kaliningraph.types.Π3A
import ai.hypergraph.kaliningraph.types.*

typealias TreeMatrix = FreeMatrix<Forest>
typealias Forest = Set<Tree>
Expand Down Expand Up @@ -33,6 +33,17 @@ class Tree constructor(
else listOf(Π3A(root, children[0].root, children[1].root)) +
children.flatMap { it.triples() }

/**
 * Collects one quintuple per node that has exactly two children, this node first,
 * then everything from the left subtree, then everything from the right.
 *
 * A quintuple is ([parent], [lsibling], [rsibling], left child root,
 * right child root). When descending, this node's root becomes the children's
 * `parent`, and the two child roots become their `lsibling` / `rsibling`.
 * Nodes whose child count is not two contribute nothing and are not descended into.
 */
fun quintuples(parent: String = "NIL", lsibling: String = "NIL", rsibling: String = "NIL"): List<Π5A<Σᐩ>> {
  if (children.size != 2) return listOf()

  val leftLabel = children[0].root
  val rightLabel = children[1].root
  val here = Π5A(parent, lsibling, rsibling, leftLabel, rightLabel)

  // Both children receive the same context: (this root, left label, right label).
  return listOf(here) + children.flatMap { it.quintuples(root, leftLabel, rightLabel) }
}

/**
 * Sums the [pcfgMap] entries of every (root, left child root, right child root)
 * triple found in this tree.
 *
 * A node with fewer than two children contributes 0.0 itself but its children
 * (if any) are still scored; a triple absent from [pcfgMap] also contributes 0.0.
 *
 * NOTE(review): the table values are added as-is, so the result is a log
 * probability only if [pcfgMap] already stores log-scale scores — confirm
 * against the code that builds the table.
 */
fun logProb(pcfgMap: Map<Π3A<Σᐩ>, Int>): Double {
  // Guard on the child count before indexing children[1]; checking only
  // isEmpty() would throw on a single-child node.
  val ownScore =
    if (children.size < 2) 0.0
    else pcfgMap[root to children[0].root to children[1].root]?.toDouble() ?: 0.0
  return ownScore + children.sumOf { it.logProb(pcfgMap) }
}

fun toGraph(j: Σᐩ = "0"): LabeledGraph =
LabeledGraph { LGVertex(root, "$root.$j").let { it - it } } +
children.foldIndexed(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import kotlin.jvm.JvmName
typealias Π2A<A> = Π2<A, A>
typealias Π3A<A> = Π3<A, A, A>
typealias Π4A<A> = Π4<A, A, A, A>
typealias Π5A<A> = Π5<A, A, A, A, A>

// Multimorphic arrays
data class Π1<A>(val π1: A)/*: V1<A> by VT(π1)*/
Expand Down

0 comments on commit d7321f3

Please sign in to comment.