Skip to content

Commit

Permalink
start of egg stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
vezwork committed Aug 2, 2024
1 parent b6a10f5 commit c746f2a
Show file tree
Hide file tree
Showing 12 changed files with 924 additions and 111 deletions.
1 change: 1 addition & 0 deletions dist/demo/2024_07/equality_saturation.js
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"use strict";
259 changes: 203 additions & 56 deletions dist/demo/2024_07/rewrite.js
Original file line number Diff line number Diff line change
@@ -1,27 +1,135 @@
"use strict";
const match = (pattern, structure) => {
if (structure[0] !== pattern[0])
return false;
const a = Array.isArray(pattern[1])
? match(pattern[1], structure[1])
: { [pattern[1]]: structure[1] };
if (!a)
return false;
const b = Array.isArray(pattern[2])
? match(pattern[2], structure[2])
: { [pattern[2]]: structure[2] };
if (!b)
return false;
return {
...a,
...b,
};
};
const subst = (matches, [op, a, b]) => [
op,
Array.isArray(a) ? subst(matches, a) : matches[a],
Array.isArray(b) ? subst(matches, b) : matches[b],
const variable = ({ name, pred, match }) => ({
kind: "variable",
name,
pred,
match,
withPred: (newPred) => variable({ name, pred: newPred, match }),
withMatch: (newMatch) => variable({ name, pred, match: newMatch }),
});
const v = ([name]) => variable({ name, pred: () => true, match: (matches) => matches[name] });
const isNumber = (n) => typeof n === "number";
const MY_RULES = [
{
from: ["=", ["+", v `a`, v `b`], v `c`],
to: ["=", v `a`, ["-", v `c`, v `b`]],
},
{
from: ["=", ["-", v `a`, v `b`], v `c`],
to: ["=", v `a`, ["+", v `c`, v `b`]],
},
{
from: ["=", v `a`, v `b`],
to: ["=", v `b`, v `a`],
},
{
from: ["+", v `a`, v `b`],
to: ["+", v `b`, v `a`],
},
{
from: ["+", v `a`.withPred(isNumber), v `b`.withPred(isNumber)],
to: v `a + b`.withMatch(({ a, b }) => a + b),
},
{
from: ["+", v `a`, v `a`],
to: ["*", 2, v `a`],
},
{
from: ["*", v `a`, v `b`],
to: ["*", v `b`, v `a`],
},
{
from: ["=", ["*", v `a`, v `b`], v `c`],
to: ["=", v `a`, ["/", v `c`, v `b`]],
},
{
from: ["=", ["/", v `a`, v `b`], v `c`],
to: ["=", v `a`, ["*", v `c`, v `b`]],
},
];
const MATCH_FAIL = "MATCH_FAIL";
const match = (pattern, structure, matchesSoFar = {}) => {
if (pattern.kind === "variable") {
if (!pattern.pred(structure, matchesSoFar))
return MATCH_FAIL;
const prevMatch = pattern.match(matchesSoFar);
if (prevMatch && JSON.stringify(prevMatch) !== JSON.stringify(structure))
return MATCH_FAIL;
return { [pattern.name]: structure };
}
else if (Array.isArray(pattern)) {
if (!Array.isArray(structure))
return MATCH_FAIL;
let matches = {};
for (let i = 0; i < pattern.length; i++) {
const p = pattern[i];
const s = structure[i];
const m = match(p, s, matches);
if (m === MATCH_FAIL)
return MATCH_FAIL;
matches = { ...matches, ...m };
}
return matches;
}
else {
if (pattern !== structure)
return MATCH_FAIL;
return {};
}
};
const getRecursiveMatches = (pattern, structure, res = new Map()) => {
res.set(structure, match(pattern, structure));
if (Array.isArray(structure))
for (const substructure of structure)
getRecursiveMatches(pattern, substructure, res);
return res;
};
const SUBST_FAIL = "SUBST_FAIL";
const subst = (matches, structure) => {
if (structure.kind === "variable")
return structure.match(matches) ?? SUBST_FAIL;
else if (Array.isArray(structure)) {
const ms = structure.map((s) => subst(matches, s));
if (ms.some((m) => m === SUBST_FAIL))
return SUBST_FAIL;
return ms;
}
else
return structure;
};
const APPLY_FAIL = "APPLY_FAIL";
const applyRule = ({ from, to, success, failure }, structure) => {
const m = match(from, structure);
if (m === MATCH_FAIL) {
if (failure)
failure(MATCH_FAIL);
return APPLY_FAIL;
}
const s = subst(m, to);
if (s === SUBST_FAIL) {
if (failure)
failure(SUBST_FAIL);
return APPLY_FAIL;
}
if (success)
success();
return s;
};
const tryAllRulesRecursively = (rules, structure) => {
return [
...(Array.isArray(structure)
? structure.flatMap((s, i) => tryAllRulesRecursively(rules, s).map((res) => structure.with(i, res)))
: []),
...rules
.map((r) => {
const res = applyRule(r, structure);
if (res === APPLY_FAIL)
return APPLY_FAIL;
return res;
})
.filter((subRes) => subRes !== APPLY_FAIL),
];
};
const rewriteStep = (input, rules) => {
const neue = [
...(Array.isArray(input[1])
Expand All @@ -31,44 +139,83 @@ const rewriteStep = (input, rules) => {
? rewriteStep(input[2], rules).map((res) => [input[0], input[1], res])
: []),
];
for (const { from, to } of rules) {
const m = match(from, input);
if (!m)
for (const rule of rules) {
const res = applyRule(rule, input);
if (res === APPLY_FAIL)
continue;
neue.push(subst(m, to));
neue.push(res);
}
return neue;
};
const rules = [
{
from: ["=", ["+", "a", "b"], "c"],
to: ["=", "a", ["-", "c", "b"]],
},
{
from: ["=", ["-", "a", "b"], "c"],
to: ["=", "a", ["+", "c", "b"]],
},
{
from: ["=", "a", "b"],
to: ["=", "b", "a"],
},
//const myInput = ["=", ["-", "w", 2], 3];
const myInput1 = ["=", `w`, ["-", `r`, `l`]];
const myInput2 = ["=", `c`, ["+", `l`, ["/", `w`, 2]]];
// GOAL:
// input: l, r, w = r - l, c = l + w/2
// output:
// - given: solution for l and r, given w and c.
const isVar = (x) => x === "w" || x === "r" || x === "l" || x === "c";
const sToRule = (s) => {
const m = match(["=", v `lhs`.withPred(isVar), v `rhs`], s);
if (m === MATCH_FAIL)
return [];
return [
{
m_name: m.lhs,
from: v `lhs`.withPred((x) => x === m.lhs),
to: m.rhs,
success: () => console.log("SUCCESSFULLY APPLIED", m.lhs),
failure: (fa) => console.log("FAIL APPLIED", fa, m.lhs),
},
];
};
console.log("tryAllRulesRecursively 1", tryAllRulesRecursively(MY_RULES, "a"), tryAllRulesRecursively(MY_RULES, ["a"]));
console.log("tryAllRulesRecursively 2", tryAllRulesRecursively(MY_RULES, ["+", "a", "b"]));
console.log("sToRule 1", sToRule(["=", v `x`, ["+", 1, 2]]));
console.log("sToRule 2", sToRule(["=", ["sin", v `x`], ["+", 1, 2]]));
console.log("sToRule 3", tryAllRulesRecursively([
{
from: ["+", "a", "b"],
to: ["+", "b", "a"],
from: v `lhs`.withPred((x) => x.kind === "variable" && x.name === "x"),
to: 22,
},
/* next: {
from: ["-", { name: "a", pred: isNum }, { name: "b", isNum }],
to: { args: ["a", "b"], calc: (a, b)=>a-b },
},*/
];
const mySet = new Set();
const see = (ob) => mySet.add(JSON.stringify(ob));
const isSeen = (ob) => mySet.has(JSON.stringify(ob));
const myInput = ["=", ["-", "w", 2], 3];
const toTry = [myInput];
while (toTry.length > 0) {
const i = toTry.pop();
see(i);
toTry.push(...rewriteStep(i, rules).filter((res) => !isSeen(res)));
], ["sin", v `x`]), tryAllRulesRecursively(sToRule(["=", v `x`, ["+", 1, 2]]), [
"+",
33,
["-", 1, v `x`],
]));
const mySet1 = new Set();
const mySet2 = new Set();
const see = (set, ob) => set.add(JSON.stringify(ob));
const isSeen = (set, ob) => set.has(JSON.stringify(ob));
const rules = [...MY_RULES];
const toTry1 = [myInput1];
const toTry2 = [myInput2];
for (let i = 0; i < 12; i++) {
if (toTry1.length === 0 && toTry2.length === 0)
break;
if (toTry1.length > 0) {
const i = toTry1.pop();
see(mySet1, i);
const newStructuresToTry = tryAllRulesRecursively(rules, i).filter((res) => !isSeen(mySet1, res));
toTry1.push(...newStructuresToTry);
rules.push(...newStructuresToTry.flatMap(sToRule));
}
if (toTry2.length > 0) {
const i = toTry2.pop();
see(mySet2, i);
const newStructuresToTry = tryAllRulesRecursively(rules, i).filter((res) => !isSeen(mySet2, res));
toTry2.push(...newStructuresToTry);
rules.push(...newStructuresToTry.flatMap(sToRule));
}
console.log("hi", toTry1.length);
}
console.log([...mySet.values()].map(JSON.parse));
console.log("go go go", [...mySet1.values()].map(JSON.parse), [...mySet2.values()].map(JSON.parse), rules);
// console.log(
// "SELECT",
// [...mySet.values()]
// .map(JSON.parse)
// .filter(
// ([[_1, v1], [_2, v2]]) => v1.kind === "variable" && v2.kind === "variable"
// )
// .map(([[_1, v1], [_2, v2]]) => [v1.name, v2.name])
// );
1 change: 1 addition & 0 deletions dist/demo/2024_08/e_graph.html
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
<script src="./e_graph.js" type="module"></script>
85 changes: 85 additions & 0 deletions dist/demo/2024_08/e_graph.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// e init (tree) => e-graph
// e match (pattern, e-graph) => e-node[]
// e sub (result, e-node, e-graph) => e-graph

const node = (value, ...children) => {
const res = {
isNode: true,
value,
children: children,
parent: null,
};
for (const child of children) child.parent = res;
return res;
};
const vari = (v, ...children) => {
const res = {
isNode: true,
var: v,
children: children,
parent: null,
};
for (const child of children) child.parent = res;
return res;
};
const nodeEq = (n1, n2) => {
if (n1.value !== n2.value) return false;
else if (n1.children.length !== n2.children.length) return false;
else return n1.children.every((v, i) => nodeEq(v, n2.children[i]));
};

const a = node("f", node(1), node(2));
const b = node("f", node(1), node(2));

console.log("nodeEq!", nodeEq(a, b));

const eClassFromNode = (node, parentENode = null) => {
const eNode = { isENode: true, value: node.value };
eNode.children = node.children.map((n) => eClassFromNode(n, eNode));
return {
isEClass: true,
eNodes: [eNode],
parents: [parentENode],
};
};

console.log("eClassFromNode!", eClassFromNode(a));

const eClassMatches = (patternNode, eClass) => {
return eClass.eNodes.flatMap((en) => eNodeMatches(patternNode, en));
};
const eNodeMatches = (patternNode, eNode) => {
if (patternNode.var === undefined && eNode.value !== patternNode.value)
return [];
else if (patternNode.children.length !== eNode.children.length) return [];
else {
const childrenMatches = eNode.children.map((ec, i) =>
eClassMatches(patternNode.children[i], ec)
);
return [
...gogo(
childrenMatches,
patternNode.var ? { [patternNode.var]: eNode.value } : {}
),
];
}
};

const gogo = function* (childrenMatches, match) {
if (childrenMatches.length === 0) {
yield { ...match };
return;
}
for (const matches1 of childrenMatches[0]) {
for (const matches2 of gogo(childrenMatches.slice(1))) {
yield { ...match, ...matches1, ...matches2 };
}
}
};

console.log(
"eClassMatches!",
eClassMatches(vari("go", vari("1"), node(2)), eClassFromNode(a))
);

// Aside: the implicit lifting language could maybe really help simplify this.
Loading

0 comments on commit c746f2a

Please sign in to comment.