Skip to content

Commit

Permalink
describe intent more faithfully
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jan 4, 2024
1 parent 87be69c commit d589ca1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -289,16 +289,16 @@ fun CFG.enumSeqSmart(tokens: List<String>): Sequence<String> =
/**
 * Yields the strings produced by `sampleWithReplacement()` on the result of
 * `startPTree(tokens)`, or an empty sequence when no such result is available
 * (i.e. when the safe-call chain evaluates to null).
 */
fun CFG.sampleSeq(tokens: List<String>): Sequence<String> {
  val sampled = startPTree(tokens)?.sampleWithReplacement()
  return sampled ?: emptySequence()
}

fun CFG.enumTree(tokens: List<String>): Sequence<Tree> =
fun CFG.enumTrees(tokens: List<String>): Sequence<Tree> =
startPTree(tokens)?.sampleTreesWithoutReplacement() ?: sequenceOf()

fun CFG.enumSWOR(tokens: List<String>): Sequence<String> =
fun CFG.sampleSWOR(tokens: List<String>): Sequence<String> =
startPTree(tokens)?.sampleWRGD() ?: sequenceOf()

fun CFG.hammingBallRepair(tokens: List<String>): Sequence<String> =
tokens.indices.toSet().choose(5)
.map { tokens.substituteIndices(it) { it, i -> "_" } }
.flatMap { enumSWOR(it).take(100) }
.flatMap { sampleSWOR(it).take(100) }

fun CFG.repairSeq(tokens: List<String>): Sequence<String> =
tokens.intersperse(2, "ε").let { prompt ->
Expand All @@ -315,32 +315,33 @@ fun CFG.fastRepairSeq(tokens: List<String>, spacing: Int = 2, holes: Int = 6): S
prompt.indices.toSet().choose(minOf(holes, prompt.size - 1))
.map { prompt.substituteIndices(it) { _, _ -> "_" } }
// ifEmpty {...} is a hack to ensure the sequence emits values at a steady frequency
.flatMap { enumSWOR(it).take(100).ifEmpty { sequenceOf("ε") } }
.flatMap { sampleSWOR(it).take(100).ifEmpty { sequenceOf("ε") } }
.map { it.removeEpsilon() }
}.flatMap { if (it.isEmpty()) sequenceOf(it) else minimizeFix(tokens, it.tokenizeByWhitespace()) { this in language } }

// Note: the repairs are not deduplicated, because we try to avoid long delays
// between repairs; callers must append .distinct() if they want unique results.
fun CFG.fasterRepairSeq(tokens: List<String>, spacing: Int = 2, holes: Int = 6): Sequence<String> {
var levenshteinBlanket = tokens
var blanketSeq = emptySequence<String>().iterator()
val uniformSeq = tokens.intersperse(spacing, "ε").let { prompt ->
prompt.indices.toSet().choose(minOf(holes, prompt.size - 1))
.map { prompt.substituteIndices(it) { _, _ -> "_" } }
// ifEmpty {...} is a hack to ensure the sequence emits values at a steady frequency
.flatMap { enumSWOR(it).take(100).ifEmpty { sequenceOf("ε") } }
.flatMap { sampleSWOR(it).take(100).ifEmpty { sequenceOf("ε") } }
}.iterator()

val distinct1 = mutableSetOf<String>()
val distinct2 = mutableSetOf<String>()

return generateSequence {
if (blanketSeq.hasNext() && Random.nextBoolean()) blanketSeq.next()//.also { println("Blanket: $it") }
else if (uniformSeq.hasNext()) uniformSeq.next()//.also { println("Uniform: $it") }
if (blanketSeq.hasNext() && Random.nextBoolean()) blanketSeq.next()
else if (uniformSeq.hasNext()) uniformSeq.next()
else null
}.map { it.removeEpsilon() }.flatMap {
if (it.isEmpty() || it in distinct1) sequenceOf(it)
else {
distinct1.add(it)
minimizeFix(tokens, it.tokenizeByWhitespace()) { this in language }.onEach { minfix ->
if (it.isEmpty() || !distinct1.add(it)) sequenceOf(it)
else minimizeFix(tokens, it.tokenizeByWhitespace()) { this in language }
.onEach { minfix ->
if (minfix !in distinct2) {
distinct2.add(minfix)
val newBlanket =
Expand All @@ -352,7 +353,6 @@ fun CFG.fasterRepairSeq(tokens: List<String>, spacing: Int = 2, holes: Int = 6):
}
}
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package ai.hypergraph.kaliningraph.parsing
import Grammars
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.parseFSA
import ai.hypergraph.kaliningraph.repair.*
import ai.hypergraph.kaliningraph.sampling.all
import kotlin.test.*
import kotlin.time.*

Expand Down Expand Up @@ -203,11 +201,11 @@ class BarHillelTest {
val gram = Grammars.ifThen
val origStr = "if ( true or false then true else 1"
val tokens = origStr.tokenizeByWhitespace()
val levDist = 3
val levBall = makeLevFSA(origStr, levDist, gram.terminals)
val maxLevDist = 3
val levBall = makeLevFSA(origStr, maxLevDist, gram.terminals)
val intGram = gram.intersectLevFSA(levBall)
val clock = TimeSource.Monotonic.markNow()
val template = List(tokens.size + levDist) { "_" }
val template = List(tokens.size + maxLevDist) { "_" }
val lbhSet = intGram.enumSeqMinimal(template, tokens)
.onEachIndexed { i, it ->
if (i < 100) {
Expand All @@ -218,22 +216,23 @@ class BarHillelTest {
val actDist= levenshtein(origStr, it)
assertTrue(it in gram.language)
assertTrue(levBall.recognizes(it))
assertTrue(actDist <= levDist)
assertTrue(actDist <= maxLevDist)
}.toSet()
// Found 221 minimal solutions using Levenshtein/Bar-Hillel in 23.28s
.also { println("Found ${it.size} minimal solutions using " +
"Levenshtein/Bar-Hillel in ${clock.elapsedNow()}") }

val prbSet = Grammars.ifThen.fasterRepairSeq(tokens, 1, 3)
.onEachIndexed { i, it ->
.distinct().mapNotNull {
val levDistance = levenshtein(origStr, it)
if (levDistance < levDist) {
if (levDistance < maxLevDist) {
println("Found ($levDistance): " + levenshteinAlign(origStr, it).paintANSIColors())
assertTrue(it in Grammars.ifThen.language)
assertTrue(levBall.recognizes(it))
assertTrue(it in intGram.language)
assertTrue(it in lbhSet)
}
it
} else null
}.toList()
.also { println("Found ${it.size} minimal solutions using " +
"Probabilistic repair in ${clock.elapsedNow()}") }
Expand All @@ -247,35 +246,28 @@ class BarHillelTest {
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val origStr = "NAME = ( NAME . NAME ( NAME NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val levDist = 3
val levBall = makeLevFSA(toRepair, levDist, gram.terminals)
val maxLevDist = 3
val levBall = makeLevFSA(toRepair, maxLevDist, gram.terminals)
// println(levBall.toDot())
// throw Exception("")
val intGram = gram.intersectLevFSA(levBall)
// val part= intGram.nonterminals.map { it.substringAfter(',')
// .substringBefore(',') }.toSet().filter { it in gram.nonterminals }
//
// println("Part: $part")
// println("Nopart: ${gram.nonterminals - part}")

// .also { println("LEV ∩ CFG grammar:\n${it.pretty}") }
// println(intGram.prettyPrint())
val clock = TimeSource.Monotonic.markNow()

val template = List(toRepair.size + levDist - 1) { "_" }
val template = List(toRepair.size + maxLevDist - 1) { "_" }

val lbhSet = intGram.enumSeqMinimal(template, toRepair)
.onEachIndexed { i, it ->
if (i < 100) {
val levAlign = levenshteinAlign(origStr, it).paintANSIColors()
println(levAlign)
val pf = intGram.enumTree(it.tokenizeByWhitespace()).toList()
val pf = intGram.enumTrees(it.tokenizeByWhitespace()).toList()
println("Found " + pf.size + " parse trees")
println(pf.first().prettyPrint())
println("\n\n")
}

assertTrue(levenshtein(origStr, it) <= levDist)
assertTrue(levenshtein(origStr, it) <= maxLevDist)
assertTrue(it in gram.language)
assertTrue(levBall.recognizes(it))
}.toSet()
Expand All @@ -288,16 +280,17 @@ class BarHillelTest {

val s2pg = Grammars.seq2parsePythonCFG
val prbSet = s2pg.fasterRepairSeq(toRepair, 1, 3)
.onEachIndexed { i, it ->
.distinct().mapIndexedNotNull { i, it ->
val levDistance = levenshtein(origStr, it)
if (i < 100) println("Found ($levDistance): " + levenshteinAlign(origStr, it).paintANSIColors())
if (levDistance < levDist) {
if (levDistance < maxLevDist) {
println("Checking: $it")
assertTrue(it in s2pg.language)
assertTrue(levBall.recognizes(it))
assertTrue(it in intGram.language)
assertTrue(it in lbhSet)
}
it
} else null
}.toList()
// Found 3912 minimal solutions using Probabilistic repair in 11m 51.535605250s
.also { println("Found ${it.size} minimal solutions using " +
Expand Down Expand Up @@ -340,16 +333,17 @@ class BarHillelTest {
"Levenshtein/Bar-Hillel in ${clock.elapsedNow()}") }

val s2pg = Grammars.seq2parsePythonCFG
val prbSet = s2pg.fasterRepairSeq(toRepair, 1, 2)
.onEachIndexed { i, it ->
val prbSet = s2pg.fasterRepairSeq(toRepair, 1, 2).distinct()
.mapIndexedNotNull { i, it ->
val levDistance = levenshtein(origStr, it)
if (levDistance < levDist) {
println("Found ($levDistance): " + levenshteinAlign(origStr, it).paintANSIColors())
assertTrue(it in s2pg.language)
assertTrue(levBall.recognizes(it))
assertTrue(it in intGram.language)
assertTrue(it in lbhSet)
}
it
} else null
}.toList()
.also { println("Found ${it.size} minimal solutions using " +
"Probabilistic repair in ${clock.elapsedNow()}") }
Expand All @@ -375,13 +369,14 @@ class BarHillelTest {

val s2pg = Grammars.seq2parsePythonCFG
s2pg.fasterRepairSeq(toRepair, 1, 2).distinct()
.onEachIndexed { i, it ->
.mapIndexedNotNull { i, it ->
val levDistance = levenshtein(origStr, it)
if (levDistance <= levDist) {
println("Found ($levDistance): " + levenshteinAlign(origStr, it).paintANSIColors())
assertTrue(it in s2pg.language)
assertTrue(levBall.recognizes(it))
}
it
} else null
}.takeWhile { clock.elapsedNow().inWholeSeconds < 30 }.toList()
.also { println("Found ${it.size} minimal solutions using " +
"Probabilistic repair in ${clock.elapsedNow()}") }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class SeqValiantTest {
val refStr = "while ( <term> ) ;"
tinyC.parse(refStr)?.let { println(it) }
println(refStr in tinyC.language)
tinyC.fasterRepairSeq(refStr.tokenizeByWhitespace()).take(100).forEach {
tinyC.fasterRepairSeq(refStr.tokenizeByWhitespace()).distinct().take(100).forEach {
println(it)
assertTrue(it in tinyC.language, "Invalid solution: $it")
}
Expand Down

0 comments on commit d589ca1

Please sign in to comment.