diff --git a/.github/workflows/swift-codegen.yml b/.github/workflows/swift-codegen.yml index d3e55ca75f19..3714b00d25ce 100644 --- a/.github/workflows/swift-codegen.yml +++ b/.github/workflows/swift-codegen.yml @@ -19,9 +19,14 @@ jobs: cache: 'pip' - uses: ./.github/actions/fetch-codeql - uses: bazelbuild/setup-bazelisk@v2 - - name: Check code generation + - name: Install dependencies run: | pip install -r swift/codegen/requirements.txt + - name: Run unit tests + run: | + bazel test //swift/codegen:tests --test_output=errors + - name: Check that code was generated + run: | bazel run //swift/codegen git add swift git diff --exit-code --stat HEAD diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d90a7982a572..5676f36512b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,3 +40,10 @@ repos: language: system entry: bazel run //swift/codegen pass_filenames: false + + - id: swift-codegen-unit-tests + name: Run Swift code generation unit tests + files: ^swift/codegen + language: system + entry: bazel test //swift/codegen:tests + pass_filenames: false diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000000..a118fa835ef1 --- /dev/null +++ b/conftest.py @@ -0,0 +1 @@ +# this empty file adds the repo root to PYTHON_PATH when running pytest diff --git a/swift/codegen/.coverage b/swift/codegen/.coverage new file mode 100644 index 000000000000..1f3d1b3104cf Binary files /dev/null and b/swift/codegen/.coverage differ diff --git a/swift/codegen/BUILD.bazel b/swift/codegen/BUILD.bazel index 25bcf30c2970..ae65641fd63b 100644 --- a/swift/codegen/BUILD.bazel +++ b/swift/codegen/BUILD.bazel @@ -1,4 +1,31 @@ py_binary( name = "codegen", - srcs = glob(["**/*.py"]), + srcs = glob([ + "lib/*.py", + "*.py", + ]), +) + +py_library( + name = "test_utils", + testonly = True, + srcs = ["test/utils.py"], + deps = [":codegen"], +) + +[ + py_test( + name = src[len("test/"):-len(".py")], + size = "small", + srcs = [src], + deps = [ + ":codegen", + ":test_utils", + ], + ) + for src in glob(["test/test_*.py"]) +] + +test_suite( + name = "tests", ) diff --git a/swift/codegen/dbschemegen.py b/swift/codegen/dbschemegen.py index eb3a90288580..79eb5983e74a 100755 --- a/swift/codegen/dbschemegen.py +++ b/swift/codegen/dbschemegen.py @@ -3,8 +3,8 @@ import inflection -from lib import paths, schema, generator -from lib.dbscheme import * +from swift.codegen.lib import paths, schema, generator +from swift.codegen.lib.dbscheme import * log = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def cls_to_dbscheme(cls: schema.Class): def get_declarations(data: schema.Schema): - return [d for cls in data.classes.values() for d in cls_to_dbscheme(cls)] + return [d for cls in data.classes for d in cls_to_dbscheme(cls)] def get_includes(data: schema.Schema, include_dir: pathlib.Path): @@ -73,11 +73,10 @@ def get_includes(data: schema.Schema, include_dir: pathlib.Path): def generate(opts, renderer): - input = opts.schema.resolve() - out = opts.dbscheme.resolve() + input = opts.schema + out = opts.dbscheme - with open(input) as src: - data = schema.load(src) + data = schema.load(input) dbscheme = DbScheme(src=input.relative_to(paths.swift_dir), includes=get_includes(data, include_dir=input.parent), diff --git a/swift/codegen/lib/options.py b/swift/codegen/lib/options.py index 373564b65489..df71192c65eb 100644 --- a/swift/codegen/lib/options.py +++ b/swift/codegen/lib/options.py @@ -10,13 +10,17 @@ def _init_options(): Option("--verbose", "-v", action="store_true") - Option("--schema", tags=["schema"], type=pathlib.Path, default=paths.swift_dir / "codegen/schema.yml") - Option("--dbscheme", tags=["dbscheme"], type=pathlib.Path, default=paths.swift_dir / "ql/lib/swift.dbscheme") - Option("--ql-output", tags=["ql"], type=pathlib.Path, default=paths.swift_dir / "ql/lib/codeql/swift/generated") - Option("--ql-stub-output", tags=["ql"], type=pathlib.Path, default=paths.swift_dir / "ql/lib/codeql/swift/elements") + Option("--schema", tags=["schema"], type=_abspath, default=paths.swift_dir / "codegen/schema.yml") + Option("--dbscheme", tags=["dbscheme"], type=_abspath, default=paths.swift_dir / "ql/lib/swift.dbscheme") + Option("--ql-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/generated") + Option("--ql-stub-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/elements") Option("--codeql-binary", tags=["ql"], default="codeql") +def _abspath(x): + return pathlib.Path(x).resolve() + + _options = collections.defaultdict(list) diff --git a/swift/codegen/lib/paths.py b/swift/codegen/lib/paths.py index e458a49f6569..e202b17674e7 100644 --- a/swift/codegen/lib/paths.py +++ b/swift/codegen/lib/paths.py @@ -5,13 +5,16 @@ import os try: - _workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']) # <- means we are using bazel run + _workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve() # <- means we are using bazel run swift_dir = _workspace_dir / 'swift' - lib_dir = swift_dir / 'codegen' / 'lib' except KeyError: _this_file = pathlib.Path(__file__).resolve() swift_dir = _this_file.parents[2] - lib_dir = _this_file.parent +lib_dir = swift_dir / 'codegen' / 'lib' +templates_dir = lib_dir / 'templates' -exe_file = pathlib.Path(sys.argv[0]).resolve() +try: + exe_file = pathlib.Path(sys.argv[0]).resolve().relative_to(swift_dir) +except ValueError: + exe_file = pathlib.Path(sys.argv[0]).name diff --git a/swift/codegen/lib/ql.py b/swift/codegen/lib/ql.py new file mode 100644 index 000000000000..f285eeca4236 --- /dev/null +++ b/swift/codegen/lib/ql.py @@ -0,0 +1,88 @@ +import pathlib +from dataclasses import dataclass, field +from typing import List, ClassVar + +import inflection + + +@dataclass +class QlParam: + param: str + type: str = None + first: bool = False + + +@dataclass +class QlProperty: + singular: str + type: str + tablename: str + tableparams: List[QlParam] + plural: str = None + params: List[QlParam] = field(default_factory=list) + first: bool = False + local_var: str = "x" + + def __post_init__(self): + if self.params: + self.params[0].first = True + while self.local_var in (p.param for p in self.params): + self.local_var += "_" + assert self.tableparams + if self.type_is_class: + self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams] + self.tableparams = [QlParam(x) for x in self.tableparams] + self.tableparams[0].first = True + + @property + def indefinite_article(self): + if self.plural: + return "An" if self.singular[0] in "AEIO" else "A" + + @property + def type_is_class(self): + return self.type[0].isupper() + + +@dataclass +class QlClass: + template: ClassVar = 'ql_class' + + name: str + bases: List[str] = field(default_factory=list) + final: bool = False + properties: List[QlProperty] = field(default_factory=list) + dir: pathlib.Path = pathlib.Path() + imports: List[str] = field(default_factory=list) + + def __post_init__(self): + self.bases = sorted(self.bases) + if self.properties: + self.properties[0].first = True + + @property + def db_id(self): + return "@" + inflection.underscore(self.name) + + @property + def root(self): + return not self.bases + + @property + def path(self): + return self.dir / self.name + + +@dataclass +class QlStub: + template: ClassVar = 'ql_stub' + + name: str + base_import: str + + +@dataclass +class QlImportList: + template: ClassVar = 'ql_imports' + + imports: List[str] = field(default_factory=list) diff --git a/swift/codegen/lib/render.py b/swift/codegen/lib/render.py index 1129f42b85a2..8fe066fe6643 100644 --- a/swift/codegen/lib/render.py +++ b/swift/codegen/lib/render.py @@ -19,8 +19,7 @@ class Renderer: """ Template renderer using mustache templates in the `templates` directory """ def __init__(self): - self.r = pystache.Renderer(search_dirs=str(paths.lib_dir / "templates"), escape=lambda u: u) - self.generator = paths.exe_file.relative_to(paths.swift_dir) + self._r = pystache.Renderer(search_dirs=str(paths.lib_dir / "templates"), escape=lambda u: u) self.written = set() def render(self, data, output: pathlib.Path): @@ -32,7 +31,7 @@ def render(self, data, output: pathlib.Path): """ mnemonic = type(data).__name__ output.parent.mkdir(parents=True, exist_ok=True) - data = self.r.render_name(data.template, data, generator=self.generator) + data = self._r.render_name(data.template, data, generator=paths.exe_file) with open(output, "w") as out: out.write(data) log.debug(f"generated {mnemonic} {output.name}") @@ -41,6 +40,5 @@ def render(self, data, output: pathlib.Path): def cleanup(self, existing): """ Remove files in `existing` for which no `render` has been called """ for f in existing - self.written: - if f.is_file(): - f.unlink() - log.info(f"removed {f.name}") + f.unlink(missing_ok=True) + log.info(f"removed {f.name}") diff --git a/swift/codegen/lib/schema.py b/swift/codegen/lib/schema.py index 4030676a174f..5871c90ed3bd 100644 --- a/swift/codegen/lib/schema.py +++ b/swift/codegen/lib/schema.py @@ -3,7 +3,6 @@ import pathlib import re from dataclasses import dataclass, field -from enum import Enum, auto from typing import List, Set, Dict, ClassVar import yaml @@ -47,7 +46,7 @@ class Class: @dataclass class Schema: - classes: Dict[str, Class] + classes: List[Class] includes: Set[str] = field(default_factory=set) @@ -65,6 +64,7 @@ def _parse_property(name, type): class _DirSelector: """ Default output subdirectory selector for generated QL files, based on the `_directories` global field""" + def __init__(self, dir_to_patterns): self.selector = [(re.compile(p), pathlib.Path(d)) for d, p in dir_to_patterns] self.selector.append((re.compile(""), pathlib.Path())) @@ -73,19 +73,19 @@ def get(self, name): return next(d for p, d in self.selector if p.search(name)) -def load(file): - """ Parse the schema from `file` """ - data = yaml.load(file, Loader=yaml.SafeLoader) +def load(path): + """ Parse the schema from the file at `path` """ + with open(path) as input: + data = yaml.load(input, Loader=yaml.SafeLoader) grouper = _DirSelector(data.get("_directories", {}).items()) - ret = Schema(classes={cls: Class(cls, dir=grouper.get(cls)) for cls in data if not cls.startswith("_")}, - includes=set(data.get("_includes", []))) - assert root_class_name not in ret.classes - ret.classes[root_class_name] = Class(root_class_name) + classes = {root_class_name: Class(root_class_name)} + assert root_class_name not in data + classes.update((cls, Class(cls, dir=grouper.get(cls))) for cls in data if not cls.startswith("_")) for name, info in data.items(): if name.startswith("_"): continue assert name[0].isupper() - cls = ret.classes[name] + cls = classes[name] for k, v in info.items(): if not k.startswith("_"): cls.properties.append(_parse_property(k, v)) @@ -94,11 +94,11 @@ def load(file): v = [v] for base in v: cls.bases.add(base) - ret.classes[base].derived.add(name) + classes[base].derived.add(name) elif k == "_dir": cls.dir = pathlib.Path(v) if not cls.bases: cls.bases.add(root_class_name) - ret.classes[root_class_name].derived.add(name) + classes[root_class_name].derived.add(name) - return ret + return Schema(classes=list(classes.values()), includes=set(data.get("_includes", []))) diff --git a/swift/codegen/qlgen.py b/swift/codegen/qlgen.py index 0ea6d60efa38..4aaceb5fd14b 100755 --- a/swift/codegen/qlgen.py +++ b/swift/codegen/qlgen.py @@ -1,129 +1,43 @@ #!/usr/bin/env python3 import logging -import pathlib import subprocess -from dataclasses import dataclass, field -from typing import List, ClassVar import inflection -from lib import schema, paths, generator +from swift.codegen.lib import schema, paths, generator, ql log = logging.getLogger(__name__) -@dataclass -class QlParam: - param: str - type: str = None - first: bool = False - - -@dataclass -class QlProperty: - singular: str - type: str - tablename: str - tableparams: List[QlParam] - plural: str = None - params: List[QlParam] = field(default_factory=list) - first: bool = False - local_var: str = "x" - - def __post_init__(self): - if self.params: - self.params[0].first = True - while self.local_var in (p.param for p in self.params): - self.local_var += "_" - assert self.tableparams - if self.type_is_class: - self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams] - self.tableparams = [QlParam(x) for x in self.tableparams] - self.tableparams[0].first = True - - @property - def indefinite_article(self): - if self.plural: - return "An" if self.singular[0] in "AEIO" else "A" - - @property - def type_is_class(self): - return self.type[0].isupper() - - -@dataclass -class QlClass: - template: ClassVar = 'ql_class' - - name: str - bases: List[str] - final: bool - properties: List[QlProperty] - dir: pathlib.Path - imports: List[str] = field(default_factory=list) - - def __post_init__(self): - self.bases = sorted(self.bases) - if self.properties: - self.properties[0].first = True - - @property - def db_id(self): - return "@" + inflection.underscore(self.name) - - @property - def root(self): - return not self.bases - - @property - def path(self): - return self.dir / self.name - - -@dataclass -class QlStub: - template: ClassVar = 'ql_stub' - - name: str - base_import: str - - -@dataclass -class QlImportList: - template: ClassVar = 'ql_imports' - - imports: List[str] = field(default_factory=list) - - def get_ql_property(cls: schema.Class, prop: schema.Property): if prop.is_single: - return QlProperty( + return ql.QlProperty( singular=inflection.camelize(prop.name), type=prop.type, tablename=inflection.tableize(cls.name), tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single], ) elif prop.is_optional: - return QlProperty( + return ql.QlProperty( singular=inflection.camelize(prop.name), type=prop.type, tablename=inflection.tableize(f"{cls.name}_{prop.name}"), tableparams=["this", "result"], ) elif prop.is_repeated: - return QlProperty( + return ql.QlProperty( singular=inflection.singularize(inflection.camelize(prop.name)), plural=inflection.pluralize(inflection.camelize(prop.name)), type=prop.type, tablename=inflection.tableize(f"{cls.name}_{prop.name}"), tableparams=["this", "index", "result"], - params=[QlParam("index", type="int")], + params=[ql.QlParam("index", type="int")], ) def get_ql_class(cls: schema.Class): - return QlClass( + return ql.QlClass( name=cls.name, bases=cls.bases, final=not cls.derived, @@ -137,7 +51,7 @@ def get_import(file): return str(stem).replace("/", ".") -def get_types_used_by(cls: QlClass): +def get_types_used_by(cls: ql.QlClass): for b in cls.bases: yield b for p in cls.properties: @@ -146,7 +60,7 @@ def get_types_used_by(cls: QlClass): yield param.type -def get_classes_used_by(cls: QlClass): +def get_classes_used_by(cls: ql.QlClass): return sorted(set(t for t in get_types_used_by(cls) if t[0].isupper())) @@ -164,34 +78,32 @@ def format(codeql, files): def generate(opts, renderer): - input = opts.schema.resolve() - out = opts.ql_output.resolve() - stub_out = opts.ql_stub_output.resolve() + input = opts.schema + out = opts.ql_output + stub_out = opts.ql_stub_output existing = {q for q in out.rglob("*.qll")} existing |= {q for q in stub_out.rglob("*.qll") if is_generated(q)} - with open(input) as src: - data = schema.load(src) + data = schema.load(input) - classes = [get_ql_class(cls) for cls in data.classes.values()] + classes = [get_ql_class(cls) for cls in data.classes] imports = {} for c in classes: imports[c.name] = get_import(stub_out / c.path) for c in classes: - assert not c.final or c.bases, c.name qll = (out / c.path).with_suffix(".qll") c.imports = [imports[t] for t in get_classes_used_by(c)] renderer.render(c, qll) stub_file = (stub_out / c.path).with_suffix(".qll") if not stub_file.is_file() or is_generated(stub_file): - stub = QlStub(name=c.name, base_import=get_import(qll)) + stub = ql.QlStub(name=c.name, base_import=get_import(qll)) renderer.render(stub, stub_file) # for example path/to/syntax/generated -> path/to/syntax.qll include_file = stub_out.with_suffix(".qll") - all_imports = QlImportList(v for _, v in sorted(imports.items())) + all_imports = ql.QlImportList([v for _, v in sorted(imports.items())]) renderer.render(all_imports, include_file) renderer.cleanup(existing) diff --git a/swift/codegen/requirements.txt b/swift/codegen/requirements.txt index df5110f65355..b8959a4b15df 100644 --- a/swift/codegen/requirements.txt +++ b/swift/codegen/requirements.txt @@ -1,3 +1,4 @@ pystache pyyaml inflection +pytest diff --git a/swift/codegen/test/test_dbschemegen.py b/swift/codegen/test/test_dbschemegen.py new file mode 100644 index 000000000000..3fb6f798ae49 --- /dev/null +++ b/swift/codegen/test/test_dbschemegen.py @@ -0,0 +1,328 @@ +import pathlib +import sys + +from swift.codegen import dbschemegen +from swift.codegen.lib import dbscheme, paths +from swift.codegen.test.utils import * + +def generate(opts, renderer): + (out, data), = run_generation(dbschemegen.generate, opts, renderer).items() + assert out is opts.dbscheme + return data + + +def test_empty(opts, input, renderer): + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[], + ) + + +def test_includes(opts, input, renderer): + includes = ["foo", "bar"] + input.includes = includes + for i in includes: + write(opts.schema.parent / i, i + " data") + + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[ + dbscheme.DbSchemeInclude( + src=schema_dir / i, + data=i + " data", + ) for i in includes + ], + declarations=[], + ) + + +def test_empty_final_class(opts, input, renderer): + input.classes = [ + schema.Class("Object"), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbTable( + name="objects", + columns=[ + dbscheme.DbColumn('id', '@object', binding=True), + ] + ) + ], + ) + + +def test_final_class_with_single_scalar_field(opts, input, renderer): + input.classes = [ + + schema.Class("Object", properties=[ + schema.SingleProperty("foo", "bar"), + ]), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbTable( + name="objects", + columns=[ + dbscheme.DbColumn('id', '@object', binding=True), + dbscheme.DbColumn('foo', 'bar'), + ] + ) + ], + ) + + +def test_final_class_with_single_class_field(opts, input, renderer): + input.classes = [ + schema.Class("Object", properties=[ + schema.SingleProperty("foo", "Bar"), + ]), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbTable( + name="objects", + columns=[ + dbscheme.DbColumn('id', '@object', binding=True), + dbscheme.DbColumn('foo', '@bar'), + ] + ) + ], + ) + + +def test_final_class_with_optional_field(opts, input, renderer): + input.classes = [ + schema.Class("Object", properties=[ + schema.OptionalProperty("foo", "bar"), + ]), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbTable( + name="objects", + columns=[ + dbscheme.DbColumn('id', '@object', binding=True), + ] + ), + dbscheme.DbTable( + name="object_foos", + keyset=dbscheme.DbKeySet(["id"]), + columns=[ + dbscheme.DbColumn('id', '@object'), + dbscheme.DbColumn('foo', 'bar'), + ] + ), + ], + ) + + +def test_final_class_with_repeated_field(opts, input, renderer): + input.classes = [ + schema.Class("Object", properties=[ + schema.RepeatedProperty("foo", "bar"), + ]), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbTable( + name="objects", + columns=[ + dbscheme.DbColumn('id', '@object', binding=True), + ] + ), + dbscheme.DbTable( + name="object_foos", + keyset=dbscheme.DbKeySet(["id", "index"]), + columns=[ + dbscheme.DbColumn('id', '@object'), + dbscheme.DbColumn('index', 'int'), + dbscheme.DbColumn('foo', 'bar'), + ] + ), + ], + ) + + +def test_final_class_with_more_fields(opts, input, renderer): + input.classes = [ + schema.Class("Object", properties=[ + schema.SingleProperty("one", "x"), + schema.SingleProperty("two", "y"), + schema.OptionalProperty("three", "z"), + schema.RepeatedProperty("four", "w"), + ]), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbTable( + name="objects", + columns=[ + dbscheme.DbColumn('id', '@object', binding=True), + dbscheme.DbColumn('one', 'x'), + dbscheme.DbColumn('two', 'y'), + ] + ), + dbscheme.DbTable( + name="object_threes", + keyset=dbscheme.DbKeySet(["id"]), + columns=[ + dbscheme.DbColumn('id', '@object'), + dbscheme.DbColumn('three', 'z'), + ] + ), + dbscheme.DbTable( + name="object_fours", + keyset=dbscheme.DbKeySet(["id", "index"]), + columns=[ + dbscheme.DbColumn('id', '@object'), + dbscheme.DbColumn('index', 'int'), + dbscheme.DbColumn('four', 'w'), + ] + ), + ], + ) + + +def test_empty_class_with_derived(opts, input, renderer): + input.classes = [ + schema.Class( + name="Base", + derived={"Left", "Right"}), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbUnion( + lhs="@base", + rhs=["@left", "@right"], + ), + ], + ) + + +def test_class_with_derived_and_single_property(opts, input, renderer): + input.classes = [ + schema.Class( + name="Base", + derived={"Left", "Right"}, + properties=[ + schema.SingleProperty("single", "Prop"), + ]), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbUnion( + lhs="@base", + rhs=["@left", "@right"], + ), + dbscheme.DbTable( + name="bases", + keyset=dbscheme.DbKeySet(["id"]), + columns=[ + dbscheme.DbColumn('id', '@base'), + dbscheme.DbColumn('single', '@prop'), + ] + ) + ], + ) + + +def test_class_with_derived_and_optional_property(opts, input, renderer): + input.classes = [ + schema.Class( + name="Base", + derived={"Left", "Right"}, + properties=[ + schema.OptionalProperty("opt", "Prop"), + ]), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbUnion( + lhs="@base", + rhs=["@left", "@right"], + ), + dbscheme.DbTable( + name="base_opts", + keyset=dbscheme.DbKeySet(["id"]), + columns=[ + dbscheme.DbColumn('id', '@base'), + dbscheme.DbColumn('opt', '@prop'), + ] + ) + ], + ) + + +def test_class_with_derived_and_repeated_property(opts, input, renderer): + input.classes = [ + schema.Class( + name="Base", + derived={"Left", "Right"}, + properties=[ + schema.RepeatedProperty("rep", "Prop"), + ]), + ] + assert generate(opts, renderer) == dbscheme.DbScheme( + src=schema_file, + includes=[], + declarations=[ + dbscheme.DbUnion( + lhs="@base", + rhs=["@left", "@right"], + ), + dbscheme.DbTable( + name="base_reps", + keyset=dbscheme.DbKeySet(["id", "index"]), + columns=[ + dbscheme.DbColumn('id', '@base'), + dbscheme.DbColumn('index', 'int'), + dbscheme.DbColumn('rep', '@prop'), + ] + ) + ], + ) + + +def test_dbcolumn_name(): + assert dbscheme.DbColumn("foo", "some_type").name == "foo" + + +@pytest.mark.parametrize("keyword", dbscheme.dbscheme_keywords) +def test_dbcolumn_keyword_name(keyword): + assert dbscheme.DbColumn(keyword, "some_type").name == keyword + "_" + + +@pytest.mark.parametrize("type,binding,lhstype,rhstype", [ + ("builtin_type", False, "builtin_type", "builtin_type ref"), + ("builtin_type", True, "builtin_type", "builtin_type ref"), + ("@at_type", False, "int", "@at_type ref"), + ("@at_type", True, "unique int", "@at_type"), +]) +def test_dbcolumn_types(type, binding, lhstype, rhstype): + col = dbscheme.DbColumn("foo", type, binding) + assert col.lhstype == lhstype + assert col.rhstype == rhstype + + +if __name__ == '__main__': + sys.exit(pytest.main()) diff --git a/swift/codegen/test/test_qlgen.py b/swift/codegen/test/test_qlgen.py new file mode 100644 index 000000000000..9379ce317868 --- /dev/null +++ b/swift/codegen/test/test_qlgen.py @@ -0,0 +1,199 @@ +import subprocess +import sys + +import mock + +from swift.codegen import qlgen +from swift.codegen.lib import ql, paths +from swift.codegen.test.utils import * + + +@pytest.fixture(autouse=True) +def run_mock(): + with mock.patch("subprocess.run") as ret: + yield ret + + +stub_path = lambda: paths.swift_dir / "ql/lib/stub/path" +ql_output_path = lambda: paths.swift_dir / "ql/lib/other/path" +import_file = lambda: stub_path().with_suffix(".qll") +stub_import_prefix = "stub.path." +gen_import_prefix = "other.path." +index_param = ql.QlParam("index", "int") + + +def generate(opts, renderer, written=None): + opts.ql_stub_output = stub_path() + opts.ql_output = ql_output_path() + renderer.written = written or [] + return run_generation(qlgen.generate, opts, renderer) + + +def test_empty(opts, input, renderer): + assert generate(opts, renderer) == { + import_file(): ql.QlImportList() + } + + +def test_one_empty_class(opts, input, renderer): + input.classes = [ + schema.Class("A") + ] + assert generate(opts, renderer) == { + import_file(): ql.QlImportList([stub_import_prefix + "A"]), + stub_path() / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "A"), + ql_output_path() / "A.qll": ql.QlClass(name="A", final=True), + } + + +def test_hierarchy(opts, input, renderer): + input.classes = [ + schema.Class("D", bases={"B", "C"}), + schema.Class("C", bases={"A"}, derived={"D"}), + schema.Class("B", bases={"A"}, derived={"D"}), + schema.Class("A", derived={"B", "C"}), + ] + assert generate(opts, renderer) == { + import_file(): ql.QlImportList([stub_import_prefix + cls for cls in "ABCD"]), + stub_path() / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "A"), + stub_path() / "B.qll": ql.QlStub(name="B", base_import=gen_import_prefix + "B"), + stub_path() / "C.qll": ql.QlStub(name="C", base_import=gen_import_prefix + "C"), + stub_path() / "D.qll": ql.QlStub(name="D", base_import=gen_import_prefix + "D"), + ql_output_path() / "A.qll": ql.QlClass(name="A"), + ql_output_path() / "B.qll": ql.QlClass(name="B", bases=["A"], imports=[stub_import_prefix + "A"]), + ql_output_path() / "C.qll": ql.QlClass(name="C", bases=["A"], imports=[stub_import_prefix + "A"]), + ql_output_path() / "D.qll": ql.QlClass(name="D", final=True, bases=["B", "C"], + imports=[stub_import_prefix + cls for cls in "BC"]), + + } + + +def test_single_property(opts, input, renderer): + input.classes = [ + schema.Class("MyObject", properties=[schema.SingleProperty("foo", "bar")]), + ] + assert generate(opts, renderer) == { + import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]), + stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), + ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[ + ql.QlProperty(singular="Foo", type="bar", tablename="my_objects", tableparams=["this", "result"]), + ]) + } + + +def test_single_properties(opts, input, renderer): + input.classes = [ + schema.Class("MyObject", properties=[ + schema.SingleProperty("one", "x"), + schema.SingleProperty("two", "y"), + schema.SingleProperty("three", "z"), + ]), + ] + assert generate(opts, renderer) == { + import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]), + stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), + ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[ + ql.QlProperty(singular="One", type="x", tablename="my_objects", tableparams=["this", "result", "_", "_"]), + ql.QlProperty(singular="Two", type="y", tablename="my_objects", tableparams=["this", "_", "result", "_"]), + ql.QlProperty(singular="Three", type="z", tablename="my_objects", tableparams=["this", "_", "_", "result"]), + ]) + } + + +def test_optional_property(opts, input, renderer): + input.classes = [ + schema.Class("MyObject", properties=[schema.OptionalProperty("foo", "bar")]), + ] + assert generate(opts, renderer) == { + import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]), + stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), + ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[ + ql.QlProperty(singular="Foo", type="bar", tablename="my_object_foos", tableparams=["this", "result"]), + ]) + } + + +def test_repeated_property(opts, input, renderer): + input.classes = [ + schema.Class("MyObject", properties=[schema.RepeatedProperty("foo", "bar")]), + ] + assert generate(opts, renderer) == { + import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]), + stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), + ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[ + ql.QlProperty(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", params=[index_param], + tableparams=["this", "index", "result"]), + ]) + } + + +def test_single_class_property(opts, input, renderer): + input.classes = [ + schema.Class("MyObject", properties=[schema.SingleProperty("foo", "Bar")]), + schema.Class("Bar"), + ] + assert generate(opts, renderer) == { + import_file(): ql.QlImportList([stub_import_prefix + cls for cls in ("Bar", "MyObject")]), + stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), + stub_path() / "Bar.qll": ql.QlStub(name="Bar", base_import=gen_import_prefix + "Bar"), + ql_output_path() / "MyObject.qll": ql.QlClass( + name="MyObject", final=True, imports=[stub_import_prefix + "Bar"], properties=[ + ql.QlProperty(singular="Foo", type="Bar", tablename="my_objects", tableparams=["this", "result"]), + ], + ), + ql_output_path() / "Bar.qll": ql.QlClass(name="Bar", final=True) + } + + +def test_class_dir(opts, input, renderer): + dir = pathlib.Path("another/rel/path") + input.classes = [ + schema.Class("A", derived={"B"}, dir=dir), + schema.Class("B", bases={"A"}), + ] + assert generate(opts, renderer) == { + import_file(): ql.QlImportList([ + stub_import_prefix + "another.rel.path.A", + stub_import_prefix + "B", + ]), + stub_path() / dir / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "another.rel.path.A"), + stub_path() / "B.qll": ql.QlStub(name="B", base_import=gen_import_prefix + "B"), + ql_output_path() / dir / "A.qll": ql.QlClass(name="A", dir=dir), + ql_output_path() / "B.qll": ql.QlClass(name="B", final=True, bases=["A"], + imports=[stub_import_prefix + "another.rel.path.A"]) + } + + +def test_format(opts, input, renderer, run_mock): + opts.codeql_binary = "my_fake_codeql" + run_mock.return_value.stderr = "some\nlines\n" + generate(opts, renderer, written=["foo", "bar"]) + assert run_mock.mock_calls == [ + mock.call(["my_fake_codeql", "query", "format", "--in-place", "--", "foo", "bar"], + check=True, stderr=subprocess.PIPE, text=True), + ] + + +def test_empty_cleanup(opts, input, renderer): + generate(opts, renderer) + assert renderer.mock_calls[-1] == mock.call.cleanup(set()) + + +def test_empty_cleanup(opts, input, renderer, tmp_path): + opts.ql_output = tmp_path / "gen" + opts.ql_stub_output = tmp_path / "stub" + renderer.written = [] + ql_a = opts.ql_output / "A.qll" + ql_b = opts.ql_output / "B.qll" + stub_a = opts.ql_stub_output / "A.qll" + stub_b = opts.ql_stub_output / "B.qll" + write(ql_a) + write(ql_b) + write(stub_a, "// generated\nfoo\n") + write(stub_b, "bar\n") + run_generation(qlgen.generate, opts, renderer) + assert renderer.mock_calls[-1] == mock.call.cleanup({ql_a, ql_b, stub_a}) + + +if __name__ == '__main__': + sys.exit(pytest.main()) diff --git a/swift/codegen/test/test_render.py b/swift/codegen/test/test_render.py new file mode 100644 index 000000000000..3f12845df0bb --- /dev/null +++ b/swift/codegen/test/test_render.py @@ -0,0 +1,79 @@ +import sys +from unittest import mock + +import pytest + +from swift.codegen.lib import paths +from swift.codegen.lib import render + + +@pytest.fixture +def pystache_renderer_cls(): + with mock.patch("pystache.Renderer") as ret: + yield ret + + +@pytest.fixture +def pystache_renderer(pystache_renderer_cls): + ret = mock.Mock() + pystache_renderer_cls.side_effect = (ret,) + return ret + + +@pytest.fixture +def sut(pystache_renderer): + return render.Renderer() + + +def test_constructor(pystache_renderer_cls, sut): + pystache_init, = pystache_renderer_cls.mock_calls + assert set(pystache_init.kwargs) == {'search_dirs', 'escape'} + assert pystache_init.kwargs['search_dirs'] == str(paths.templates_dir) + an_object = object() + assert pystache_init.kwargs['escape'](an_object) is an_object + assert sut.written == set() + + +def test_render(pystache_renderer, sut): + data = mock.Mock() + output = mock.Mock() + with mock.patch("builtins.open", mock.mock_open()) as output_stream: + sut.render(data, output) + assert pystache_renderer.mock_calls == [ + mock.call.render_name(data.template, data, generator=paths.exe_file), + ], pystache_renderer.mock_calls + assert output_stream.mock_calls == [ + mock.call(output, 'w'), + mock.call().__enter__(), + mock.call().write(pystache_renderer.render_name.return_value), + mock.call().__exit__(None, None, None), + ] + assert sut.written == {output} + + +def test_written(sut): + data = [mock.Mock() for _ in range(4)] + output = [mock.Mock() for _ in data] + with mock.patch("builtins.open", mock.mock_open()) as output_stream: + for d, o in zip(data, output): + sut.render(d, o) + assert sut.written == set(output) + + +def test_cleanup(sut): + data = [mock.Mock() for _ in range(4)] + output = [mock.Mock() for _ in data] + with mock.patch("builtins.open", mock.mock_open()) as output_stream: + for d, o in zip(data, output): + sut.render(d, o) + expected_erased = [mock.Mock() for _ in range(3)] + existing = set(expected_erased + output[2:]) + sut.cleanup(existing) + for f in expected_erased: + assert f.mock_calls == [mock.call.unlink(missing_ok=True)] + for f in output: + assert f.unlink.mock_calls == [] + + +if __name__ == '__main__': + sys.exit(pytest.main()) diff --git a/swift/codegen/test/test_schema.py b/swift/codegen/test/test_schema.py new file mode 100644 index 000000000000..9703eafe1b8c --- /dev/null +++ b/swift/codegen/test/test_schema.py @@ -0,0 +1,158 @@ +import io +import pathlib +import sys + +import mock +import pytest + +import swift.codegen.lib.schema as schema +from swift.codegen.test.utils import * + +root_name = schema.root_class_name + +@pytest.fixture +def load(tmp_path): + file = tmp_path / "schema.yml" + def ret(yml): + write(file, yml) + return schema.load(file) + + return ret + +def test_empty_schema(load): + ret = load("{}") + assert ret.classes == [schema.Class(root_name)] + assert ret.includes == set() + + +def test_one_empty_class(load): + ret = load(""" +MyClass: {} +""") + assert ret.classes == [ + schema.Class(root_name, derived={'MyClass'}), + schema.Class('MyClass', bases={root_name}), + ] + + +def test_two_empty_classes(load): + ret = load(""" +MyClass1: {} +MyClass2: {} +""") + assert ret.classes == [ + schema.Class(root_name, derived={'MyClass1', 'MyClass2'}), + schema.Class('MyClass1', bases={root_name}), + schema.Class('MyClass2', bases={root_name}), + ] + + +def test_two_empty_chained_classes(load): + ret = load(""" +MyClass1: {} +MyClass2: + _extends: MyClass1 +""") + assert ret.classes == [ + schema.Class(root_name, derived={'MyClass1'}), + schema.Class('MyClass1', bases={root_name}, derived={'MyClass2'}), + schema.Class('MyClass2', bases={'MyClass1'}), + ] + + +def test_empty_classes_diamond(load): + ret = load(""" +A: {} +B: {} +C: + _extends: + - A + - B +""") + assert ret.classes == [ + schema.Class(root_name, derived={'A', 'B'}), + schema.Class('A', bases={root_name}, derived={'C'}), + schema.Class('B', bases={root_name}, derived={'C'}), + schema.Class('C', bases={'A', 'B'}), + ] + + +def test_dir(load): + ret = load(""" +A: + _dir: other/dir +""") + assert ret.classes == [ + schema.Class(root_name, derived={'A'}), + schema.Class('A', bases={root_name}, dir=pathlib.Path("other/dir")), + ] + + +def test_directory_filter(load): + ret = load(""" +_directories: + first/dir: '[xy]' + second/dir: foo$ + third/dir: bar$ +Afoo: {} +Bbar: {} +Abar: {} +Bfoo: {} +Ax: {} +Ay: {} +A: {} +""") + assert ret.classes == [ + schema.Class(root_name, derived={'Afoo', 'Bbar', 'Abar', 'Bfoo', 'Ax', 'Ay', 'A'}), + schema.Class('Afoo', bases={root_name}, dir=pathlib.Path("second/dir")), + schema.Class('Bbar', bases={root_name}, dir=pathlib.Path("third/dir")), + schema.Class('Abar', bases={root_name}, dir=pathlib.Path("third/dir")), + schema.Class('Bfoo', bases={root_name}, dir=pathlib.Path("second/dir")), + schema.Class('Ax', bases={root_name}, dir=pathlib.Path("first/dir")), + schema.Class('Ay', bases={root_name}, dir=pathlib.Path("first/dir")), + schema.Class('A', bases={root_name}, dir=pathlib.Path()), + ] + + +def test_directory_filter_override(load): + ret = load(""" +_directories: + one/dir: ^A$ +A: + _dir: other/dir +""") + assert ret.classes == [ + schema.Class(root_name, derived={'A'}), + schema.Class('A', bases={root_name}, dir=pathlib.Path("other/dir")), + ] + + +def test_lowercase_rejected(load): + with pytest.raises(AssertionError): + load("aLowercase: {}") + + +def test_digit_rejected(load): + with pytest.raises(AssertionError): + load("1digit: {}") + + +def test_properties(load): + ret = load(""" +A: + one: string + two: int? + three: bool* +""") + assert ret.classes == [ + schema.Class(root_name, derived={'A'}), + schema.Class('A', bases={root_name}, properties=[ + schema.SingleProperty('one', 'string'), + schema.OptionalProperty('two', 'int'), + schema.RepeatedProperty('three', 'bool'), + ]), + ] + + +if __name__ == '__main__': + sys.exit(pytest.main()) diff --git a/swift/codegen/test/utils.py b/swift/codegen/test/utils.py new file mode 100644 index 000000000000..d4b40a9e155e --- /dev/null +++ b/swift/codegen/test/utils.py @@ -0,0 +1,50 @@ +import pathlib +from unittest import mock + +import pytest + +from swift.codegen.lib import render, schema + +schema_dir = pathlib.Path("a", "dir") +schema_file = schema_dir / "schema.yml" + + +def write(out, contents=""): + out.parent.mkdir(parents=True, exist_ok=True) + with open(out, "w") as out: + out.write(contents) + + +@pytest.fixture +def renderer(): + return mock.Mock(spec=render.Renderer()) + + +@pytest.fixture +def opts(): + return mock.MagicMock() + + +@pytest.fixture(autouse=True) +def override_paths(tmp_path): + with mock.patch("swift.codegen.lib.paths.swift_dir", tmp_path): + yield + + +@pytest.fixture +def input(opts, tmp_path): + opts.schema = tmp_path / schema_file + with mock.patch("swift.codegen.lib.schema.load") as load_mock: + load_mock.return_value = schema.Schema([]) + yield load_mock.return_value + assert load_mock.mock_calls == [ + mock.call(opts.schema) + ], load_mock.mock_calls + + +def run_generation(generate, opts, renderer): + output = {} + + renderer.render.side_effect = lambda data, out: output.__setitem__(out, data) + generate(opts, renderer) + return output diff --git a/swift/ql/lib/swift.dbscheme b/swift/ql/lib/swift.dbscheme index 6c0dc88a8be7..0aff926e6c0e 100644 --- a/swift/ql/lib/swift.dbscheme +++ b/swift/ql/lib/swift.dbscheme @@ -15,6 +15,16 @@ answer_to_life_the_universe_and_everything( // from codegen/schema.yml +@element = + @argument +| @file +| @generic_context +| @iterable_decl_context +| @locatable +| @location +| @type +; + files( unique int id: @file, string name: string ref @@ -1886,13 +1896,3 @@ integer_literal_exprs( unique int id: @integer_literal_expr, string string_value: string ref ); - -@element = - @argument -| @file -| @generic_context -| @iterable_decl_context -| @locatable -| @location -| @type -;