Skip to content

Commit

Permalink
try using markov chain to select candidate fragments
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Sep 22, 2023
1 parent 83bc0a6 commit df725b7
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.sampling.choose
import ai.hypergraph.kaliningraph.types.*
import kotlin.jvm.JvmName
import kotlin.reflect.KProperty
import kotlin.time.*

typealias Σᐩ = String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ fun FastRandomSet<Edit>.resample(maxTake: Int,

// Enumerates powerset levels from the bottom up, skipping the empty set
private fun Edit.subedits(): Sequence<Sequence<List<Pair<Int, Σᐩ>>>> =
(1..size).asSequence()
.map { choose(it).map { it.toList() } }
(1..size).asSequence().map { choose(it).map { it.toList() } }

fun List<Σᐩ>.apply(edit: Edit): List<Σᐩ> {
val res = toMutableList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*

// Returns all syntactically valid strings ordered by distance to withRespect
// Tokenizes s on whitespace, hands the tokens to solveSortedFP together with
// metric, sorts whatever comes back and collects the asString form of each
// result into a set. Falls back to an empty set when solveSortedFP yields
// null, or when any Exception is thrown (its stack trace is printed first).
fun CFG.sortAll(s: Σᐩ, metric: ChoiceMetric): Set<Σᐩ> =
  try {
    val ranked = solveSortedFP(s.tokenizeByWhitespace(), metric)?.sorted()
    ranked?.map { it.asString }?.toSet() ?: setOf()
  } catch (e: Exception) {
    e.printStackTrace()
    setOf()
  }
// Convenience overload: splits s into whitespace-delimited tokens and
// forwards them, along with metric, to solve.
fun CFG.solve(s: Σᐩ, metric: ChoiceMetric): Set<Σᐩ> {
  val tokens = s.tokenizeByWhitespace()
  return solve(tokens, metric)
}

// Runs solveSortedFP on the token list s with metric, sorts the results and
// collects the asString form of each into a set. Any Exception is printed and
// treated like a null result; a null result becomes an empty set.
fun CFG.solve(s: List<Σᐩ>, metric: ChoiceMetric): Set<Σᐩ> {
  val found: Set<Σᐩ>? = try {
    solveSortedFP(s, metric)?.sorted()?.map { it.asString }?.toSet()
  } catch (e: Exception) {
    e.printStackTrace()
    null
  }
  return found ?: setOf()
}

fun CFG.solveSortedFP(
tokens: List<Σᐩ>,
Expand Down Expand Up @@ -43,6 +45,7 @@ fun CFG.sortwiseAlgebra(metric: ChoiceMetric): Ring<Sort> =

const val MAX_CAPACITY = 100
// X ⊗ Z := { w | <x, z> ∈ X × Z, (w -> xz) ∈ P }
// Greedily selects candidate string fragments according to ChoiceMetric
fun CFG.join(X: Sort, Z: Sort, metric: ChoiceMetric = { it.weight }): Sort =
bimap.TRIPL.filter { (_, x, z) -> x in X && z in Z }
.map { (w, x, z) ->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.hasBalancedBrackets
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.tensor.seekFixpoint
import ai.hypergraph.kaliningraph.types.π2
import kotlinx.datetime.Clock
Expand Down Expand Up @@ -363,7 +363,7 @@ class SetValiantTest {
val holExpr = "_ _ _ _ _ _ _ _ _ _"

measureTime {
val solutions = ocamlCFG.sortAll(holExpr, levMetric("( false curry )"))
val solutions = ocamlCFG.solve(holExpr, levMetric("( false curry )"))
println("Found: ${solutions.size} unique solutions")
solutions.forEach { println(it); assertTrue("$it was invalid!") { ocamlCFG.isValid(it) } }
}.also { println("Finished in ${it.inWholeMilliseconds}ms.") }
Expand Down Expand Up @@ -569,11 +569,12 @@ Yield_Arg -> From_Keyword Test | Testlist_Endcomma
*/
@Test
fun testPythonRepairs() {
val reference = "NAME = ( NAME"
val template = "_ _ _ _ _ _ _"//List(it.size + 2) { "_" }.joinToString(" ")
val refStr = "NAME = ( NAME"
val refLst = refStr.tokenizeByWhitespace()
val template = List(refLst.size + 3) { "_" }.joinToString(" ")
measureTime {
seq2parsePythonCFG.sortAll(template, levMetric(reference))
.onEach { println(it) }
seq2parsePythonCFG.solve(template, levMetric(refStr))
.onEach { println("Δ=${levenshtein(it, refStr)}: $it") }
.also { println("Found ${it.size} solutions!") }
}.also { println("Finished in ${it.inWholeMilliseconds}ms.") }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ open class MarkovChain<T>(

// Computes perplexity of a sequence normalized by sequence length
fun score(seq: List<T>): Double =
-seq.windowed(memory)
if (memory < seq.size) -seq.windowed(memory)
.map { (getAtLeastOne(it) + 1) / (getAtLeastOne(it.dropLast(1) + null) + dictionary.size) }
.sumOf { ln(it) } / seq.size
else (seq.sumOf { counter.rawCounts.getEstimate(it) } + 1).toDouble() / counter.total.toDouble()

operator fun get(vararg variables: T?): Double =
if (variables.size == 1) counter.rawCounts.getEstimate(variables[0]) / counter.total.toDouble()
Expand Down

0 comments on commit df725b7

Please sign in to comment.