From 2d83ee818beb14c2b9acabbab862c021cec310a8 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 6 Jan 2025 01:03:27 -0500 Subject: [PATCH] refactor common nb impls --- build_tools/cmake/llvm_cache.cmake | 21 +- projects/CMakeLists.txt | 2 + .../src => common/eudsl}/bind_vec_like.h | 20 +- .../src => common/eudsl}/type_casters.h | 10 +- projects/common/eudsl/util.h | 103 ++++++ projects/eudsl-nbgen/CMakeLists.txt | 2 + .../cmake/eudsl_nbgen-config.cmake | 15 + projects/eudsl-nbgen/src/eudsl-nbgen.cpp | 227 ++++++++---- projects/eudsl-py/src/eudslpy_ext.cpp | 273 ++++++-------- projects/eudsl-tblgen/CMakeLists.txt | 3 + projects/eudsl-tblgen/src/TGLexer.cpp | 3 - .../eudsl-tblgen/src/eudsl_tblgen_ext.cpp | 332 ++++++++---------- 12 files changed, 568 insertions(+), 443 deletions(-) rename projects/{eudsl-py/src => common/eudsl}/bind_vec_like.h (94%) rename projects/{eudsl-py/src => common/eudsl}/type_casters.h (89%) create mode 100644 projects/common/eudsl/util.h diff --git a/build_tools/cmake/llvm_cache.cmake b/build_tools/cmake/llvm_cache.cmake index 30542d78..e3d28693 100644 --- a/build_tools/cmake/llvm_cache.cmake +++ b/build_tools/cmake/llvm_cache.cmake @@ -75,26 +75,7 @@ set(LLVM_INSTALL_TOOLCHAIN_ONLY OFF CACHE BOOL "") set(LLVM_DISTRIBUTIONS MlirDevelopment CACHE STRING "") set(LLVM_MlirDevelopment_DISTRIBUTION_COMPONENTS - clangAPINotes - clangAST - clangASTMatchers - clangAnalysis - clangBasic - clangDriver - clangDriver - clangEdit - clangFormat - clangFrontend - clangLex - clangParse - clangRewrite - clangSema - clangSerialization - clangSupport - clangTooling - clangToolingCore - clangToolingInclusions - + clang-libraries clang-headers # triggers ClangConfig.cmake and etc clang-cmake-exports diff --git a/projects/CMakeLists.txt b/projects/CMakeLists.txt index e3b16175..16b70edc 100644 --- a/projects/CMakeLists.txt +++ b/projects/CMakeLists.txt @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Copyright (c) 2024. +include_directories(common) + if(NOT WIN32) add_subdirectory(eudsl-py) endif() diff --git a/projects/eudsl-py/src/bind_vec_like.h b/projects/common/eudsl/bind_vec_like.h similarity index 94% rename from projects/eudsl-py/src/bind_vec_like.h rename to projects/common/eudsl/bind_vec_like.h index 5a733f69..46f2e226 100644 --- a/projects/eudsl-py/src/bind_vec_like.h +++ b/projects/common/eudsl/bind_vec_like.h @@ -1,11 +1,12 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Copyright (c) 2024. +// Copyright (c) 2024-2025. #pragma once #include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/TypeName.h" #include #include @@ -13,7 +14,9 @@ #include #include #include +#include +namespace eudsl { struct _ArrayRef {}; struct _MutableArrayRef {}; struct _SmallVector {}; @@ -283,3 +286,18 @@ nanobind::class_ bind_iter_range(nanobind::handle scope, return cl; } + +inline void bind_array_ref_smallvector(nanobind::handle scope) { + scope.attr("T") = nanobind::type_var("T"); + arrayRef = + nanobind::class_<_ArrayRef>(scope, "ArrayRef", nanobind::is_generic(), + nanobind::sig("class ArrayRef[T]")); + mutableArrayRef = nanobind::class_<_MutableArrayRef>( + scope, "MutableArrayRef", nanobind::is_generic(), + nanobind::sig("class MutableArrayRef[T]")); + smallVector = nanobind::class_<_SmallVector>( + scope, "SmallVector", nanobind::is_generic(), + nanobind::sig("class SmallVector[T]")); +} + +} // namespace eudsl diff --git a/projects/eudsl-py/src/type_casters.h b/projects/common/eudsl/type_casters.h similarity index 89% rename from projects/eudsl-py/src/type_casters.h rename to projects/common/eudsl/type_casters.h index 66cc31d4..4b8e7d91 100644 --- a/projects/eudsl-py/src/type_casters.h +++ b/projects/common/eudsl/type_casters.h @@ -1,14 +1,20 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Copyright (c) 2024. +// Copyright (c) 2024-2025. #pragma once #include -#include +// ReSharper disable once CppUnusedIncludeDirective #include +#include +// ReSharper disable once CppUnusedIncludeDirective #include +// ReSharper disable once CppUnusedIncludeDirective +#include +// ReSharper disable once CppUnusedIncludeDirective +#include "eudsl/bind_vec_like.h" template <> struct nanobind::detail::type_caster { diff --git a/projects/common/eudsl/util.h b/projects/common/eudsl/util.h new file mode 100644 index 00000000..2b32dba2 --- /dev/null +++ b/projects/common/eudsl/util.h @@ -0,0 +1,103 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Copyright (c) 2025. + +#pragma once + +#include + +namespace eudsl { +template +struct non_copying_non_moving_class_ : nanobind::class_ { + template + NB_INLINE non_copying_non_moving_class_(nanobind::handle scope, + const char *name, + const Extra &...extra) { + nanobind::detail::type_init_data d; + + d.flags = 0; + d.align = (uint8_t)alignof(typename nanobind::class_::Alias); + d.size = (uint32_t)sizeof(typename nanobind::class_::Alias); + d.name = name; + d.scope = scope.ptr(); + d.type = &typeid(T); + + if constexpr (!std::is_same_v::Base, + T>) { + d.base = &typeid(typename nanobind::class_::Base); + d.flags |= (uint32_t)nanobind::detail::type_init_flags::has_base; + } + + if constexpr (std::is_destructible_v) { + d.flags |= (uint32_t)nanobind::detail::type_flags::is_destructible; + + if constexpr (!std::is_trivially_destructible_v) { + d.flags |= (uint32_t)nanobind::detail::type_flags::has_destruct; + d.destruct = nanobind::detail::wrap_destruct; + } + } + + if constexpr (nanobind::detail::has_shared_from_this_v) { + d.flags |= (uint32_t)nanobind::detail::type_flags::has_shared_from_this; + d.keep_shared_from_this_alive = [](PyObject *self) noexcept { + if (auto sp = nanobind::inst_ptr(self)->weak_from_this().lock()) { + nanobind::detail::keep_alive( + self, new auto(std::move(sp)), + [](void *p) noexcept { delete (decltype(sp) *)p; }); + return true; + } + return false; + }; + } + + (nanobind::detail::type_extra_apply(d, extra), ...); + + this->m_ptr = nanobind::detail::nb_type_new(&d); + } +}; + +template +constexpr auto coerceReturn(Return (*pf)(Args...)) noexcept { + return [&pf](Args &&...args) -> NewReturn { + return pf(std::forward(args)...); + }; +} + +template +constexpr auto coerceReturn(Return (Class::*pmf)(Args...), + std::false_type = {}) noexcept { + return [&pmf](Class *cls, Args &&...args) -> NewReturn { + return (cls->*pmf)(std::forward(args)...); + }; +} + +/* + * If you get + * ``` + * Called object type 'void(MyClass::*)(vector&,int)' is not a function or + * function pointer + * ``` + * it's because you're calling a member function without + * passing the `this` pointer as the first arg + */ +template +constexpr auto coerceReturn(Return (Class::*pmf)(Args...) const, + std::true_type) noexcept { + // copy the *pmf, not capture by ref + return [pmf](const Class &cls, Args &&...args) -> NewReturn { + return (cls.*pmf)(std::forward(args)...); + }; +} + +inline size_t wrap(Py_ssize_t i, size_t n) { + if (i < 0) + i += (Py_ssize_t)n; + + if (i < 0 || (size_t)i >= n) + throw nanobind::index_error(); + + return (size_t)i; +} + +} // namespace eudsl diff --git a/projects/eudsl-nbgen/CMakeLists.txt b/projects/eudsl-nbgen/CMakeLists.txt index a63b9287..0e1d0201 100644 --- a/projects/eudsl-nbgen/CMakeLists.txt +++ b/projects/eudsl-nbgen/CMakeLists.txt @@ -31,6 +31,8 @@ if(EUDSL_NBGEN_STANDALONE_BUILD) include(AddLLVM) include(AddClang) include(HandleLLVMOptions) + + include_directories(${CMAKE_CURRENT_LIST_DIR}/../common) endif() include_directories(${LLVM_INCLUDE_DIRS}) diff --git a/projects/eudsl-nbgen/cmake/eudsl_nbgen-config.cmake b/projects/eudsl-nbgen/cmake/eudsl_nbgen-config.cmake index 2548f4e5..f24b0aee 100644 --- a/projects/eudsl-nbgen/cmake/eudsl_nbgen-config.cmake +++ b/projects/eudsl-nbgen/cmake/eudsl_nbgen-config.cmake @@ -5,6 +5,19 @@ # copy-pasta from AddMLIR.cmake/AddLLVM.cmake/TableGen.cmake +set(EUDSL_NBGEN_NANOBIND_OPTIONS + -Wno-cast-qual + -Wno-deprecated-literal-operator + -Wno-covered-switch-default + -Wno-nested-anon-types + -Wno-zero-length-array + -Wno-c++98-compat-extra-semi + -Wno-c++20-extensions + $<$:-fexceptions -frtti> + $<$:-fexceptions -frtti> + $<$:/EHsc /GR> +) + function(eudsl_nbgen target input_file) set(EUDSL_NBGEN_TARGET_DEFINITIONS ${input_file}) cmake_parse_arguments(ARG "" "" "LINK_LIBS;EXTRA_INCLUDES;NAMESPACES" ${ARGN}) @@ -89,6 +102,7 @@ function(eudsl_nbgen target input_file) WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS ${EUDSL_NBGEN_EXE} ${global_tds} DEPFILE ${_depfile} + DEPENDS ${EUDSL_NBGEN_EXE} COMMENT "eudsl-nbgen: Generating ${_full_gen_file}..." ) # epic hack to specify all shards that will be generated even though we don't know them before hand @@ -137,6 +151,7 @@ function(eudsl_nbgen target input_file) endif() add_library(${target} STATIC "${_full_gen_file}.sharded.cpp" ${_shards}) + target_compile_options(${target} PUBLIC ${EUDSL_NBGEN_NANOBIND_OPTIONS}) execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --include_dir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_include_dir diff --git a/projects/eudsl-nbgen/src/eudsl-nbgen.cpp b/projects/eudsl-nbgen/src/eudsl-nbgen.cpp index de053aac..e0a3b4cd 100644 --- a/projects/eudsl-nbgen/src/eudsl-nbgen.cpp +++ b/projects/eudsl-nbgen/src/eudsl-nbgen.cpp @@ -93,6 +93,7 @@ static std::string getPyClassName(const std::string &qualifiedNameAsString) { s = std::regex_replace(s, std::regex(R"(\*)"), ""); s = std::regex_replace(s, std::regex("<"), "["); s = std::regex_replace(s, std::regex(">"), "]"); + s = std::regex_replace(s, std::regex("::"), "."); return s; } @@ -100,7 +101,6 @@ static std::string snakeCase(const std::string &name) { std::string s = name; s = std::regex_replace(s, std::regex(R"(([A-Z]+)([A-Z][a-z]))"), "$1_$2"); s = std::regex_replace(s, std::regex(R"(([a-z\d])([A-Z]))"), "$1_$2"); - s = std::regex_replace(s, std::regex("-"), "_"); std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); return s; @@ -143,20 +143,26 @@ static llvm::SmallPtrSet findOverloads(clang::FunctionDecl *decl, return results; } -// TODO(max): split this into two functions (one for names and one for types) -static std::string sanitizeNameOrType(std::string nameOrType, - int emptyIdx = 0) { - if (nameOrType == "from") - nameOrType = "from_"; - else if (nameOrType == "except") - nameOrType = "except_"; - else if (nameOrType == "") - nameOrType = std::string(emptyIdx + 1, '_'); - else if (nameOrType.rfind("ArrayRef", 0) == 0) - nameOrType = "llvm::" + nameOrType; - if (std::regex_search(nameOrType, std::regex(R"(std::__1)"))) - nameOrType = std::regex_replace(nameOrType, std::regex("std::__1"), "std"); - return nameOrType; +static std::string sanitizeName(std::string name, int emptyIdx = 0) { + if (name == "def") + name = "def_"; + if (name == "from") + name = "from_"; + else if (name == "except") + name = "except_"; + else if (name == "") + name = std::string(emptyIdx + 1, '_'); + return name; +} + +static std::string sanitizeType(std::string type) { + if (type.rfind("ArrayRef", 0) == 0) + type = "llvm::" + type; + if (std::regex_search(type, std::regex(R"(std::__1)"))) + type = std::regex_replace(type, std::regex("std::__1"), "std"); + if (std::regex_search(type, std::regex(R"(std::__cxx11)"))) + type = std::regex_replace(type, std::regex("std::__cxx11"), "std"); + return type; } // emit a lambda body to disambiguate/break ties amongst overloads @@ -187,12 +193,21 @@ std::string emitNBLambdaBody(clang::FunctionDecl *decl, n = llvm::formatv("std::move({0})", n); } std::string newParamNamesStr = llvm::join(newParamNames, ", "); + std::string return_; + auto returnTypeT = decl->getReturnType(); + if (!returnTypeT->isVoidType()) + return_ = "return"; + + bool canonical = true; + if (std::regex_search(returnTypeT.getAsString(), std::regex(R"(_t\b)"))) + canonical = false; + std::string returnType = + sanitizeType(returnTypeT.getAsString(getPrintingPolicy(canonical))); + std::string funcRef; - std::string returnType = sanitizeNameOrType( - decl->getReturnType().getAsString(getPrintingPolicy())); if (decl->isStatic() || !decl->isCXXClassMember()) { - funcRef = llvm::formatv("\n []({0}) -> {1} {{\n return {2}({3});\n }", - typedParamsStr, returnType, + funcRef = llvm::formatv("\n []({0}) -> {1} {{\n {2} {3}({4});\n }", + typedParamsStr, returnType, return_, decl->getQualifiedNameAsString(), newParamNamesStr); } else { assert(decl->isCXXClassMember() && "expected class member"); @@ -200,16 +215,76 @@ std::string emitNBLambdaBody(clang::FunctionDecl *decl, typedParamsStr = llvm::formatv("self, {0}", typedParamsStr); else typedParamsStr = "self"; + std::string methName = decl->getNameAsString(); + if (llvm::isa(decl)) + methName = "operator " + returnType; const clang::CXXRecordDecl *parentRecord = llvm::cast(decl->getParent()); - funcRef = llvm::formatv( - "\n []({0}& {1}) -> {2} {{\n return self.{3}({4});\n }", - parentRecord->getQualifiedNameAsString(), typedParamsStr, returnType, - decl->getNameAsString(), newParamNamesStr); + funcRef = + llvm::formatv("\n []({0}& {1}) -> {2} {{\n {3} self.{4}({5});\n }", + parentRecord->getQualifiedNameAsString(), typedParamsStr, + returnType, return_, methName, newParamNamesStr); } return funcRef; } +static std::string +processOperator(const std::string &fnName, + const llvm::SmallVector &argNames) { + std::string newFnName = fnName; + if (fnName == "operator!=") { + newFnName = "__ne__"; + } else if (fnName == "operator==") { + newFnName = "__eq__"; + } else if (fnName == "operator-") { + newFnName = "__neg__"; + } else if (fnName == "operator[]") { + newFnName = "__getitem__"; + } else if (fnName == "operator<") { + newFnName = "__lt__"; + } else if (fnName == "operator<=") { + newFnName = "__le__"; + } else if (fnName == "operator>") { + newFnName = "__gt__"; + } else if (fnName == "operator>=") { + newFnName = "__ge__"; + } else if (fnName == "operator%") { + newFnName = "__mod__"; + } else if (fnName == "operator*") { + if (!argNames.empty()) { + newFnName = "__mul__"; + } else { + // operator* not supported + } + } else if (fnName == "operator+" && !argNames.empty()) { + newFnName = "__add__"; + } else if (fnName == "operator->") { + // operator-> not supported + } else if (fnName == "operator!") { + // operator! not supported + } else if (fnName == "operator<<") { + // operator<< not supported + } + + return newFnName; +} + +static std::string getCppClass(const clang::CXXRecordDecl *decl) { + std::string className; + if (const clang::ClassTemplateSpecializationDecl *t = + llvm::dyn_cast(decl)) { + // TODO(max): this emits unnecessary default template args, like + // mlir::detail::TypeIDResolver + // auto td = t->getTypeForDecl(); + className = t->getTypeForDecl()->getCanonicalTypeInternal().getAsString( + getPrintingPolicy()); + } else { + className = decl->getQualifiedNameAsString(); + } + + return sanitizeType(className); +} + static bool emitClassMethodOrFunction(clang::FunctionDecl *decl, clang::CompilerInstance &ci, @@ -227,8 +302,8 @@ emitClassMethodOrFunction(clang::FunctionDecl *decl, if (std::regex_search(t.getAsString(), std::regex(R"(_t\b)"))) canonical = false; std::string paramType = t.getAsString(getPrintingPolicy(canonical)); - paramTypes.push_back(sanitizeNameOrType(paramType)); - paramNames.push_back(sanitizeNameOrType(name, i)); + paramTypes.push_back(sanitizeType(paramType)); + paramNames.push_back(sanitizeName(name, i)); } llvm::SmallPtrSet funcOverloads = @@ -237,21 +312,20 @@ emitClassMethodOrFunction(clang::FunctionDecl *decl, findOverloads(decl, ci.getSema()); std::string funcRef, nbFnName; - if (auto ctor = llvm::dyn_cast(decl)) { + if (clang::CXXConstructorDecl *ctor = + llvm::dyn_cast(decl)) { if (ctor->isDeleted()) return false; funcRef = llvm::formatv("nb::init<{0}>()", llvm::join(paramTypes, ", ")); } else { - if (funcOverloads.size() == 1 && funcTemplOverloads.empty()) { + if (funcOverloads.size() == 1 && funcTemplOverloads.empty()) funcRef = llvm::formatv("&{0}", decl->getQualifiedNameAsString()); - } else { + else funcRef = emitNBLambdaBody(decl, paramNames, paramTypes); - } - nbFnName = snakeCase(decl->getNameAsString()); - if (decl->isOverloadedOperator()) { - // TODO(max): handle overloaded operators - // nbFnName = nbFnName; + if (decl->isOverloadedOperator() || + llvm::isa(decl)) { + nbFnName = processOperator(nbFnName, paramNames); } else if (decl->isStatic() && funcOverloads.size() > 1 && llvm::any_of(funcOverloads, [](clang::FunctionDecl *m) { return !m->isStatic(); @@ -299,7 +373,7 @@ emitClassMethodOrFunction(clang::FunctionDecl *decl, if (decl->isCXXClassMember()) { const clang::CXXRecordDecl *parentRecord = llvm::cast(decl->getParent()); - scope = getNBBindClassName(parentRecord->getQualifiedNameAsString()); + scope = getNBBindClassName(getCppClass(parentRecord)); } outputFile->os() << llvm::formatv("{0}.{1}({2}{3}{4}{5}{6});\n", scope, @@ -309,13 +383,12 @@ emitClassMethodOrFunction(clang::FunctionDecl *decl, return true; } -std::string getNBScope(clang::TagDecl *decl) { +static std::string getNBScope(clang::TagDecl *decl) { std::string scope = "m"; const clang::DeclContext *declContext = decl->getDeclContext(); - if (declContext->isRecord()) { - const clang::CXXRecordDecl *ctx = - llvm::cast(declContext); - scope = getNBBindClassName(ctx->getQualifiedNameAsString()); + if (const clang::CXXRecordDecl *ctx = + llvm::dyn_cast(declContext)) { + scope = getNBBindClassName(getCppClass(ctx)); } return scope; } @@ -330,9 +403,9 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci, return false; } - std::string scope = getNBScope(decl); std::string additional = ""; - std::string className = decl->getQualifiedNameAsString(); + std::string cppClass = getCppClass(decl); + std::string autoVar = llvm::formatv("auto {0}", getNBBindClassName(cppClass)); if (decl->getNumBases() > 1) { clang::DiagnosticBuilder builder = ci.getDiagnostics().Report( decl->getLocation(), ci.getDiagnostics().getCustomDiagID( @@ -341,25 +414,27 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci, } else if (decl->getNumBases() == 1) { // handle some known bases that we've already found a wap to bind clang::CXXBaseSpecifier baseClass = *decl->bases_begin(); - std::string baseName = baseClass.getType().getAsString(getPrintingPolicy()); + clang::QualType baseType = baseClass.getType(); + std::string baseName = getCppClass(baseType->getAsCXXRecordDecl()); // TODO(max): these could be lookups on the corresponding recorddecls using // sema... if (baseName.rfind("mlir::Op<", 0) == 0) { - className = llvm::formatv("{0}, mlir::OpState", className); + cppClass = llvm::formatv("{0}, mlir::OpState", cppClass); } else if (baseName.rfind("mlir::detail::StorageUserBase<", 0) == 0) { llvm::SmallVector templParams; llvm::StringRef{baseName}.split(templParams, ","); - className = llvm::formatv("{0}, {1}", className, templParams[1]); + // TODO(max): this needs to use getCppClass not templParams[1], which is a + // string + cppClass = llvm::formatv("{0}, {1}", cppClass, templParams[1]); } else if (baseName.rfind("mlir::Dialect", 0) == 0 && - className.rfind("mlir::ExtensibleDialect") == - std::string::npos) { + cppClass.rfind("mlir::ExtensibleDialect") == std::string::npos) { // clang-format off - additional += llvm::formatv("\n .def_static(\"insert_into_registry\", [](mlir::DialectRegistry ®istry) {{ registry.insert<{0}>(); })", className); - additional += llvm::formatv("\n .def_static(\"load_into_context\", [](mlir::MLIRContext &context) {{ return context.getOrLoadDialect<{0}>(); })", className); + additional += llvm::formatv("\n .def_static(\"insert_into_registry\", [](mlir::DialectRegistry ®istry) {{ registry.insert<{0}>(); })", cppClass); + additional += llvm::formatv("\n .def_static(\"load_into_context\", [](mlir::MLIRContext &context) {{ return context.getOrLoadDialect<{0}>(); })", cppClass); // clang-format on } else if (!llvm::isa( baseClass.getType()->getAsCXXRecordDecl())) { - className = llvm::formatv("{0}, {1}", className, baseName); + cppClass = llvm::formatv("{0}, {1}", cppClass, baseName); } else { assert(llvm::isa( baseClass.getType()->getAsCXXRecordDecl()) && @@ -372,12 +447,11 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci, } } - std::string autoVar = llvm::formatv( - "auto {0}", getNBBindClassName(decl->getQualifiedNameAsString())); - + std::string scope = getNBScope(decl); + std::string pyClassName = getPyClassName(decl->getNameAsString()); outputFile->os() << llvm::formatv( - "\n{0} = nb::class_<{1}>({2}, \"{3}\"){4};\n", autoVar, className, scope, - getPyClassName(decl->getNameAsString()), additional); + "\n{0} = nb::class_<{1}>({2}, \"{3}\"){4};\n", autoVar, cppClass, scope, + pyClassName, additional); return true; } @@ -396,10 +470,9 @@ static bool emitEnum(clang::EnumDecl *decl, clang::CompilerInstance &ci, cstDecl->getQualifiedNameAsString()); if (i++ < nDecls - 1) outputFile->os() << "\n"; - else - outputFile->os() << ";\n"; } - outputFile->os() << "\n"; + + outputFile->os() << ";\n"; return true; } @@ -422,8 +495,7 @@ static bool emitField(clang::DeclaratorDecl *field, clang::CompilerInstance &ci, if (field->getType()->hasPointerRepresentation()) refInternal = ", nb::rv_policy::reference_internal"; - std::string scope = - getNBBindClassName(parentRecord->getQualifiedNameAsString()); + std::string scope = getNBBindClassName(getCppClass(parentRecord)); std::string nbFnName = llvm::formatv("\"{0}\"", snakeCase(field->getNameAsString())); outputFile->os() << llvm::formatv("{0}.{1}({2}, &{3}{4});\n", scope, defStr, @@ -433,7 +505,7 @@ static bool emitField(clang::DeclaratorDecl *field, clang::CompilerInstance &ci, } template -static bool shouldSkip(T *decl) { +static bool shouldSkip(T *decl, clang::CompilerInstance &ci) { auto *encl = llvm::dyn_cast( decl->getEnclosingNamespaceContext()); if (!encl) @@ -442,6 +514,8 @@ static bool shouldSkip(T *decl) { // bind std:: if (encl->isStdNamespace() || encl->isInStdNamespace()) return true; + if (ci.getSema().getSourceManager().isInSystemHeader(decl->getLocation())) + return true; if (!filterInNamespace(encl->getQualifiedNameAsString())) return true; if constexpr (std::is_same_v || @@ -474,7 +548,9 @@ struct BindingsVisitor ci.getDiagnostics())) {} bool VisitCXXRecordDecl(clang::CXXRecordDecl *decl) { - if (shouldSkip(decl)) + if (shouldSkip(decl, ci)) + return true; + if (decl->isAbstract()) return true; if (decl->isClass() || decl->isStruct()) { if (emitClass(decl, ci, outputFile)) @@ -537,8 +613,10 @@ struct BindingsVisitor return true; } + // TODO(max): skip definitions somehow? like FloatType::getFloat4E2M1FN which + // has both the decl and the impl in a header? bool VisitCXXMethodDecl(clang::CXXMethodDecl *decl) { - if (shouldSkip(decl) || llvm::isa(decl) || + if (shouldSkip(decl, ci) || llvm::isa(decl) || !visitedRecords.contains(decl->getParent())) return true; if (decl->isTemplated() || decl->isTemplateDecl() || @@ -557,12 +635,18 @@ struct BindingsVisitor "friend functions not supported")); return true; } + if (decl->isCopyAssignmentOperator() || decl->isMoveAssignmentOperator()) + return true; + if (decl->isDeleted()) + return true; + emitClassMethodOrFunction(decl, ci, outputFile); + return true; } bool VisitFunctionDecl(clang::FunctionDecl *decl) { - if (shouldSkip(decl) || decl->isCXXClassMember()) + if (shouldSkip(decl, ci) || decl->isCXXClassMember()) return true; // clang-format off // this @@ -580,12 +664,19 @@ struct BindingsVisitor "template functions not supported yet")); return true; } + if (decl->getFriendObjectKind()) { + clang::DiagnosticBuilder builder = ci.getDiagnostics().Report( + decl->getLocation(), ci.getDiagnostics().getCustomDiagID( + clang::DiagnosticsEngine::Note, + "friend functions not supported")); + return true; + } emitClassMethodOrFunction(decl, ci, outputFile); return true; } bool VisitEnumDecl(clang::EnumDecl *decl) { - if (shouldSkip(decl)) + if (shouldSkip(decl, ci)) return true; if (decl->getQualifiedNameAsString().rfind("unnamed enum") != std::string::npos) @@ -610,6 +701,8 @@ struct BindingsVisitor // TODO(max): this is a hack and not stable bool VisitDecl(clang::Decl *decl) { + if (ci.getSema().getSourceManager().isInSystemHeader(decl->getLocation())) + return true; const clang::DeclContext *declContext = decl->getDeclContext(); HackDeclContext *ctx = static_cast(decl->getDeclContext()); @@ -799,13 +892,15 @@ namespace nb = nanobind; using namespace nb::literals; using namespace mlir; using namespace llvm; -#include "type_casters.h" +#include "eudsl/type_casters.h" +namespace eudsl { void populate)" << finalTarget << i << R"(Module(nb::module_ &m) { )"; // clang-format on shardFile << shards[i] << std::endl; shardFile << "}" << std::endl; + shardFile << "}" << std::endl; shardFile.flush(); shardFile.close(); } @@ -847,6 +942,7 @@ void populate)" << finalTarget << i << R"(Module(nb::module_ &m) { namespace nb = nanobind; using namespace nb::literals; +namespace eudsl { void populate)" << finalTarget << R"(Module(nb::module_ &m) { )"; // clang-format on @@ -859,6 +955,7 @@ void populate)" << finalTarget << R"(Module(nb::module_ &m) { finalShardedFile << "populate" << finalTarget << i << "Module(m);" << std::endl; + finalShardedFile << "}" << std::endl; finalShardedFile << "}" << std::endl; finalShardedFile.flush(); finalShardedFile.close(); diff --git a/projects/eudsl-py/src/eudslpy_ext.cpp b/projects/eudsl-py/src/eudslpy_ext.cpp index 0b3fc7dc..c66fff48 100644 --- a/projects/eudsl-py/src/eudslpy_ext.cpp +++ b/projects/eudsl-py/src/eudslpy_ext.cpp @@ -3,12 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Copyright (c) 2024. -#include -#include -#include -#include -#include - #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" @@ -57,12 +51,16 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ThreadPool.h" -#include "bind_vec_like.h" -#include "type_casters.h" +#include + +// ReSharper disable once CppUnusedIncludeDirective +#include "eudsl/type_casters.h" +#include "eudsl/util.h" namespace nb = nanobind; using namespace nb::literals; +namespace eudsl { class FakeDialect : public mlir::Dialect { public: FakeDialect(llvm::StringRef name, mlir::MLIRContext *context, mlir::TypeID id) @@ -73,63 +71,7 @@ nb::class_<_SmallVector> smallVector; nb::class_<_ArrayRef> arrayRef; nb::class_<_MutableArrayRef> mutableArrayRef; -void bind_array_ref_smallvector(nb::handle scope) { - scope.attr("T") = nb::type_var("T"); - arrayRef = nb::class_<_ArrayRef>(scope, "ArrayRef", nb::is_generic(), - nb::sig("class ArrayRef[T]")); - mutableArrayRef = - nb::class_<_MutableArrayRef>(scope, "MutableArrayRef", nb::is_generic(), - nb::sig("class MutableArrayRef[T]")); - smallVector = nb::class_<_SmallVector>(scope, "SmallVector", nb::is_generic(), - nb::sig("class SmallVector[T]")); -} - -template -struct non_copying_non_moving_class_ : nb::class_ { - template - NB_INLINE non_copying_non_moving_class_(nb::handle scope, const char *name, - const Extra &...extra) { - nb::detail::type_init_data d; - - d.flags = 0; - d.align = (uint8_t)alignof(typename nb::class_::Alias); - d.size = (uint32_t)sizeof(typename nb::class_::Alias); - d.name = name; - d.scope = scope.ptr(); - d.type = &typeid(T); - - if constexpr (!std::is_same_v::Base, T>) { - d.base = &typeid(typename nb::class_::Base); - d.flags |= (uint32_t)nb::detail::type_init_flags::has_base; - } - - if constexpr (std::is_destructible_v) { - d.flags |= (uint32_t)nb::detail::type_flags::is_destructible; - - if constexpr (!std::is_trivially_destructible_v) { - d.flags |= (uint32_t)nb::detail::type_flags::has_destruct; - d.destruct = nb::detail::wrap_destruct; - } - } - - if constexpr (nb::detail::has_shared_from_this_v) { - d.flags |= (uint32_t)nb::detail::type_flags::has_shared_from_this; - d.keep_shared_from_this_alive = [](PyObject *self) noexcept { - if (auto sp = nb::inst_ptr(self)->weak_from_this().lock()) { - nb::detail::keep_alive( - self, new auto(std::move(sp)), - [](void *p) noexcept { delete (decltype(sp) *)p; }); - return true; - } - return false; - }; - } - - (nb::detail::type_extra_apply(d, extra), ...); - - this->m_ptr = nb::detail::nb_type_new(&d); - } -}; +extern void populateEUDSLGen_IR0Module(nb::module_ &m); void populateIRModule(nb::module_ &m) { using namespace mlir; @@ -358,8 +300,10 @@ extern void populateEUDSLGen_x86vectorModule(nb::module_ &m); // extern void populateEUDSLGen_xegpuModule(nb::module_ &m); +} // namespace eudsl + NB_MODULE(eudslpy_ext, m) { - bind_array_ref_smallvector(m); + eudsl::bind_array_ref_smallvector(m); nb::class_(m, "APFloat"); nb::class_(m, "APInt"); @@ -381,9 +325,6 @@ NB_MODULE(eudslpy_ext, m) { nb::class_(m, "TypeID"); nb::class_(m, "InterfaceMap"); - auto irModule = m.def_submodule("ir"); - populateIRModule(irModule); - nb::class_>(m, "FailureOr[bool]"); nb::class_>(m, "FailureOr[StringAttr]"); nb::class_>( @@ -429,64 +370,64 @@ NB_MODULE(eudslpy_ext, m) { nb::class_(m, "BitVector"); auto [smallVectorOfBool, arrayRefOfBool, mutableArrayRefOfBool] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfFloat, arrayRefOfFloat, mutableArrayRefOfFloat] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfInt, arrayRefOfInt, mutableArrayRefOfInt] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfChar, arrayRefOfChar, mutableArrayRefOfChar] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfDouble, arrayRefOfDouble, mutableArrayRefOfDouble] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfInt16, arrayRefOfInt16, mutableArrayRefOfInt16] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfInt32, arrayRefOfInt32, mutableArrayRefOfInt32] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfInt64, arrayRefOfInt64, mutableArrayRefOfInt64] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfUInt16, arrayRefOfUInt16, mutableArrayRefOfUInt16] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfUInt32, arrayRefOfUInt32, mutableArrayRefOfUInt32] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfUInt64, arrayRefOfUInt64, mutableArrayRefOfUInt64] = - bind_array_ref(m); + eudsl::bind_array_ref(m); // these have to precede... - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - - bind_array_ref(m); - - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - - bind_array_ref(m); - bind_array_ref(m); - // bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - - smallVector.def_static( + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + + eudsl::bind_array_ref(m); + + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + // eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + + eudsl::smallVector.def_static( "__class_getitem__", [smallVectorOfBool, smallVectorOfInt, smallVectorOfFloat, smallVectorOfInt16, smallVectorOfInt32, smallVectorOfInt64, @@ -533,7 +474,7 @@ NB_MODULE(eudslpy_ext, m) { throw std::runtime_error(errMsg); }); - smallVector.def_static( + eudsl::smallVector.def_static( "__class_getitem__", [smallVectorOfFloat, smallVectorOfInt16, smallVectorOfInt32, smallVectorOfInt64, smallVectorOfUInt16, smallVectorOfUInt32, @@ -574,17 +515,21 @@ NB_MODULE(eudslpy_ext, m) { nb::class_>( m, "iterator_range[ResultRange.UseIterator]"); - bind_iter_range, mlir::Type>( + eudsl::bind_iter_range, mlir::Type>( m, "ValueTypeRange[ValueRange]"); - bind_iter_range, mlir::Type>( + eudsl::bind_iter_range, mlir::Type>( m, "ValueTypeRange[OperandRange]"); - bind_iter_range, mlir::Type>( + eudsl::bind_iter_range, mlir::Type>( m, "ValueTypeRange[ResultRange]"); - bind_iter_like, nb::rv_policy::reference_internal>( - m, "iplist[Block]"); - bind_iter_like, - nb::rv_policy::reference_internal>(m, "iplist[Operation]"); + eudsl::bind_iter_like, + nb::rv_policy::reference_internal>(m, "iplist[Block]"); + eudsl::bind_iter_like, + nb::rv_policy::reference_internal>(m, + "iplist[Operation]"); + + auto irModule = m.def_submodule("ir"); + eudsl::populateIRModule(irModule); auto dialectsModule = m.def_submodule("dialects"); @@ -592,134 +537,134 @@ NB_MODULE(eudslpy_ext, m) { // populateEUDSLGen_accModule(accModule); auto affineModule = dialectsModule.def_submodule("affine"); - populateEUDSLGen_affineModule(affineModule); + eudsl::populateEUDSLGen_affineModule(affineModule); auto amdgpuModule = dialectsModule.def_submodule("amdgpu"); - populateEUDSLGen_amdgpuModule(amdgpuModule); + eudsl::populateEUDSLGen_amdgpuModule(amdgpuModule); // auto amxModule = dialectsModule.def_submodule("amx"); - // populateEUDSLGen_amxModule(amxModule); + // eudsl::populateEUDSLGen_amxModule(amxModule); auto arithModule = dialectsModule.def_submodule("arith"); - populateEUDSLGen_arithModule(arithModule); + eudsl::populateEUDSLGen_arithModule(arithModule); // auto arm_neonModule = dialectsModule.def_submodule("arm_neon"); - // populateEUDSLGen_arm_neonModule(arm_neonModule); + // eudsl::populateEUDSLGen_arm_neonModule(arm_neonModule); // auto arm_smeModule = dialectsModule.def_submodule("arm_sme"); - // populateEUDSLGen_arm_smeModule(arm_smeModule); + // eudsl::populateEUDSLGen_arm_smeModule(arm_smeModule); // auto arm_sveModule = dialectsModule.def_submodule("arm_sve"); - // populateEUDSLGen_arm_sveModule(arm_sveModule); + // eudsl::populateEUDSLGen_arm_sveModule(arm_sveModule); auto asyncModule = dialectsModule.def_submodule("async"); - populateEUDSLGen_asyncModule(asyncModule); + eudsl::populateEUDSLGen_asyncModule(asyncModule); auto bufferizationModule = dialectsModule.def_submodule("bufferization"); - populateEUDSLGen_bufferizationModule(bufferizationModule); + eudsl::populateEUDSLGen_bufferizationModule(bufferizationModule); auto cfModule = dialectsModule.def_submodule("cf"); - populateEUDSLGen_cfModule(cfModule); + eudsl::populateEUDSLGen_cfModule(cfModule); auto complexModule = dialectsModule.def_submodule("complex"); - populateEUDSLGen_complexModule(complexModule); + eudsl::populateEUDSLGen_complexModule(complexModule); // auto DLTIDialectModule = dialectsModule.def_submodule("DLTIDialect"); - // populateEUDSLGen_DLTIDialectModule(DLTIDialectModule); + // eudsl::populateEUDSLGen_DLTIDialectModule(DLTIDialectModule); auto emitcModule = dialectsModule.def_submodule("emitc"); - populateEUDSLGen_emitcModule(emitcModule); + eudsl::populateEUDSLGen_emitcModule(emitcModule); auto funcModule = dialectsModule.def_submodule("func"); - populateEUDSLGen_funcModule(funcModule); + eudsl::populateEUDSLGen_funcModule(funcModule); auto gpuModule = dialectsModule.def_submodule("gpu"); - populateEUDSLGen_gpuModule(gpuModule); + eudsl::populateEUDSLGen_gpuModule(gpuModule); auto indexModule = dialectsModule.def_submodule("index"); - populateEUDSLGen_indexModule(indexModule); + eudsl::populateEUDSLGen_indexModule(indexModule); // auto irdlModule = dialectsModule.def_submodule("irdl"); - // populateEUDSLGen_irdlModule(irdlModule); + // eudsl::populateEUDSLGen_irdlModule(irdlModule); auto linalgModule = dialectsModule.def_submodule("linalg"); - populateEUDSLGen_linalgModule(linalgModule); + eudsl::populateEUDSLGen_linalgModule(linalgModule); auto LLVMModule = dialectsModule.def_submodule("LLVM"); - populateEUDSLGen_LLVMModule(LLVMModule); + eudsl::populateEUDSLGen_LLVMModule(LLVMModule); auto mathModule = dialectsModule.def_submodule("math"); - populateEUDSLGen_mathModule(mathModule); + eudsl::populateEUDSLGen_mathModule(mathModule); auto memrefModule = dialectsModule.def_submodule("memref"); - populateEUDSLGen_memrefModule(memrefModule); + eudsl::populateEUDSLGen_memrefModule(memrefModule); // auto meshModule = dialectsModule.def_submodule("mesh"); - // populateEUDSLGen_meshModule(meshModule); + // eudsl::populateEUDSLGen_meshModule(meshModule); // auto ml_programModule = dialectsModule.def_submodule("ml_program"); - // populateEUDSLGen_ml_programModule(ml_programModule); + // eudsl::populateEUDSLGen_ml_programModule(ml_programModule); // auto mpiModule = dialectsModule.def_submodule("mpi"); - // populateEUDSLGen_mpiModule(mpiModule); + // eudsl::populateEUDSLGen_mpiModule(mpiModule); auto nvgpuModule = dialectsModule.def_submodule("nvgpu"); - populateEUDSLGen_nvgpuModule(nvgpuModule); + eudsl::populateEUDSLGen_nvgpuModule(nvgpuModule); auto NVVMModule = dialectsModule.def_submodule("NVVM"); - populateEUDSLGen_NVVMModule(NVVMModule); + eudsl::populateEUDSLGen_NVVMModule(NVVMModule); // auto ompModule = dialectsModule.def_submodule("omp"); - // populateEUDSLGen_ompModule(ompModule); + // eudsl::populateEUDSLGen_ompModule(ompModule); auto pdlModule = dialectsModule.def_submodule("pdl"); - populateEUDSLGen_pdlModule(pdlModule); + eudsl::populateEUDSLGen_pdlModule(pdlModule); auto pdl_interpModule = dialectsModule.def_submodule("pdl_interp"); - populateEUDSLGen_pdl_interpModule(pdl_interpModule); + eudsl::populateEUDSLGen_pdl_interpModule(pdl_interpModule); auto polynomialModule = dialectsModule.def_submodule("polynomial"); - populateEUDSLGen_polynomialModule(polynomialModule); + eudsl::populateEUDSLGen_polynomialModule(polynomialModule); // auto ptrModule = dialectsModule.def_submodule("ptr"); - // populateEUDSLGen_ptrModule(ptrModule); + // eudsl::populateEUDSLGen_ptrModule(ptrModule); // auto quantModule = dialectsModule.def_submodule("quant"); - // populateEUDSLGen_quantModule(quantModule); + // eudsl::populateEUDSLGen_quantModule(quantModule); auto ROCDLModule = dialectsModule.def_submodule("ROCDL"); - populateEUDSLGen_ROCDLModule(ROCDLModule); + eudsl::populateEUDSLGen_ROCDLModule(ROCDLModule); auto scfModule = dialectsModule.def_submodule("scf"); - populateEUDSLGen_scfModule(scfModule); + eudsl::populateEUDSLGen_scfModule(scfModule); auto shapeModule = dialectsModule.def_submodule("shape"); - populateEUDSLGen_shapeModule(shapeModule); + eudsl::populateEUDSLGen_shapeModule(shapeModule); // auto sparse_tensorModule = dialectsModule.def_submodule("sparse_tensor"); - // populateEUDSLGen_sparse_tensorModule(sparse_tensorModule); + // eudsl::populateEUDSLGen_sparse_tensorModule(sparse_tensorModule); // auto spirvModule = dialectsModule.def_submodule("spirv"); - // populateEUDSLGen_spirvModule(spirvModule); + // eudsl::populateEUDSLGen_spirvModule(spirvModule); auto tensorModule = dialectsModule.def_submodule("tensor"); - populateEUDSLGen_tensorModule(tensorModule); + eudsl::populateEUDSLGen_tensorModule(tensorModule); auto tosaModule = dialectsModule.def_submodule("tosa"); - populateEUDSLGen_tosaModule(tosaModule); + eudsl::populateEUDSLGen_tosaModule(tosaModule); // auto transformModule = dialectsModule.def_submodule("transform"); - // populateEUDSLGen_transformModule(transformModule); + // eudsl::populateEUDSLGen_transformModule(transformModule); // auto ubModule = dialectsModule.def_submodule("ub"); - // populateEUDSLGen_ubModule(ubModule); + // eudsl::populateEUDSLGen_ubModule(ubModule); // auto vectorModule = dialectsModule.def_submodule("vector"); - // populateEUDSLGen_vectorModule(vectorModule); + // eudsl::populateEUDSLGen_vectorModule(vectorModule); // auto x86vectorModule = dialectsModule.def_submodule("x86vector"); - // populateEUDSLGen_x86vectorModule(x86vectorModule); + // eudsl::populateEUDSLGen_x86vectorModule(x86vectorModule); // auto xegpuModule = dialectsModule.def_submodule("xegpu"); - // populateEUDSLGen_xegpuModule(xegpuModule); + // eudsl::populateEUDSLGen_xegpuModule(xegpuModule); } diff --git a/projects/eudsl-tblgen/CMakeLists.txt b/projects/eudsl-tblgen/CMakeLists.txt index a3ac78bc..3b6bce8c 100644 --- a/projects/eudsl-tblgen/CMakeLists.txt +++ b/projects/eudsl-tblgen/CMakeLists.txt @@ -29,6 +29,8 @@ if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_LIST_DIR) include(AddLLVM) include(AddMLIR) include(HandleLLVMOptions) + + include_directories(${CMAKE_CURRENT_LIST_DIR}/../common) endif() include_directories(${MLIR_INCLUDE_DIRS}) @@ -58,6 +60,7 @@ nanobind_add_module(eudsl_tblgen_ext NB_STATIC src/TGParser.cpp src/TGLexer.cpp ) +set_property(TARGET eudsl_tblgen_ext PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(eudsl_tblgen_ext PRIVATE LLVMTableGenCommon LLVMTableGen MLIRTableGen) set(nanobind_options diff --git a/projects/eudsl-tblgen/src/TGLexer.cpp b/projects/eudsl-tblgen/src/TGLexer.cpp index ff1c73a4..961831f3 100644 --- a/projects/eudsl-tblgen/src/TGLexer.cpp +++ b/projects/eudsl-tblgen/src/TGLexer.cpp @@ -15,14 +15,11 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" -#include "llvm/Config/llvm-config.h" // for strtoull()/strtoll() define #include "llvm/Support/Compiler.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/TableGen/Error.h" -#include #include -#include #include #include #include diff --git a/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp b/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp index 99bc8a5f..a86884f9 100644 --- a/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp +++ b/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp @@ -33,91 +33,41 @@ #include #include -#include +// ReSharper disable once CppUnusedIncludeDirective #include #include +// ReSharper disable once CppUnusedIncludeDirective #include -using namespace llvm; +#include "eudsl/util.h" +// ReSharper disable once CppUnusedIncludeDirective +#include "eudsl/type_casters.h" namespace nb = nanobind; using namespace nb::literals; -template -constexpr auto coerceReturn(Return (*pf)(Args...)) noexcept { - return [&pf](Args &&...args) -> NewReturn { - return pf(std::forward(args)...); - }; -} - -template -constexpr auto coerceReturn(Return (Class::*pmf)(Args...), - std::false_type = {}) noexcept { - return [&pmf](Class *cls, Args &&...args) -> NewReturn { - return (cls->*pmf)(std::forward(args)...); - }; -} - -/* - * If you get - * ``` - * Called object type 'void(MyClass::*)(vector&,int)' is not a function or - * function pointer - * ``` - * it's because you're calling a member function without - * passing the `this` pointer as the first arg - */ -template -constexpr auto coerceReturn(Return (Class::*pmf)(Args...) const, - std::true_type) noexcept { - // copy the *pmf, not capture by ref - return [pmf](const Class &cls, Args &&...args) -> NewReturn { - return (cls.*pmf)(std::forward(args)...); - }; -} - -template <> -struct nb::detail::type_caster { - NB_TYPE_CASTER(StringRef, const_name("str")) - - bool from_python(handle src, uint8_t, cleanup_list *) noexcept { - Py_ssize_t size; - const char *str = PyUnicode_AsUTF8AndSize(src.ptr(), &size); - if (!str) { - PyErr_Clear(); - return false; - } - value = StringRef(str, (size_t)size); - return true; - } - - static handle from_cpp(StringRef value, rv_policy, cleanup_list *) noexcept { - return PyUnicode_FromStringAndSize(value.data(), value.size()); - } -}; - // hack to expose protected Init::InitKind -struct HackInit : public Init { +struct HackInit : public llvm::Init { using InitKind = Init::InitKind; }; NB_MODULE(eudsl_tblgen_ext, m) { - auto recty = nb::class_(m, "RecTy"); + auto recty = nb::class_(m, "RecTy"); - nb::enum_(m, "RecTyKind") - .value("BitRecTyKind", RecTy::RecTyKind::BitRecTyKind) - .value("BitsRecTyKind", RecTy::RecTyKind::BitsRecTyKind) - .value("IntRecTyKind", RecTy::RecTyKind::IntRecTyKind) - .value("StringRecTyKind", RecTy::RecTyKind::StringRecTyKind) - .value("ListRecTyKind", RecTy::RecTyKind::ListRecTyKind) - .value("DagRecTyKind", RecTy::RecTyKind::DagRecTyKind) - .value("RecordRecTyKind", RecTy::RecTyKind::RecordRecTyKind); + nb::enum_(m, "RecTyKind") + .value("BitRecTyKind", llvm::RecTy::RecTyKind::BitRecTyKind) + .value("BitsRecTyKind", llvm::RecTy::RecTyKind::BitsRecTyKind) + .value("IntRecTyKind", llvm::RecTy::RecTyKind::IntRecTyKind) + .value("StringRecTyKind", llvm::RecTy::RecTyKind::StringRecTyKind) + .value("ListRecTyKind", llvm::RecTy::RecTyKind::ListRecTyKind) + .value("DagRecTyKind", llvm::RecTy::RecTyKind::DagRecTyKind) + .value("RecordRecTyKind", llvm::RecTy::RecTyKind::RecordRecTyKind); recty.def("get_rec_ty_kind", &llvm::RecTy::getRecTyKind) .def("get_record_keeper", &llvm::RecTy::getRecordKeeper, nb::rv_policy::reference_internal) .def("get_as_string", &llvm::RecTy::getAsString) - .def("__str__", &RecTy::getAsString) + .def("__str__", &llvm::RecTy::getAsString) .def("print", &llvm::RecTy::print, "os"_a) .def("dump", &llvm::RecTy::dump) .def("type_is_convertible_to", &llvm::RecTy::typeIsConvertibleTo, "rhs"_a) @@ -130,7 +80,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def_static("get", &llvm::BitRecTy::get, "rk"_a, nb::rv_policy::reference_internal) .def("get_as_string", &llvm::BitRecTy::getAsString) - .def("__str__", &BitRecTy::getAsString) + .def("__str__", &llvm::BitRecTy::getAsString) .def("type_is_convertible_to", &llvm::BitRecTy::typeIsConvertibleTo, "rhs"_a); @@ -171,7 +121,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::DagRecTy::getAsString) .def("__init__", &llvm::DagRecTy::getAsString); - nb::class_(m, "RecordRecTy") + nb::class_(m, "RecordRecTy") .def_static("classof", &llvm::RecordRecTy::classof, "rt"_a) .def_static( "get", @@ -189,8 +139,8 @@ NB_MODULE(eudsl_tblgen_ext, m) { "class"_a, nb::rv_policy::reference_internal) .def("profile", &llvm::RecordRecTy::Profile, "id"_a) .def("get_classes", - coerceReturn>(&RecordRecTy::getClasses, - nb::const_), + eudsl::coerceReturn>( + &llvm::RecordRecTy::getClasses, nb::const_), nb::rv_policy::reference_internal) .def("classes_begin", &llvm::RecordRecTy::classes_begin, nb::rv_policy::reference_internal) @@ -238,7 +188,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("is_concrete", &llvm::Init::isConcrete) .def("print", &llvm::Init::print, "os"_a) .def("get_as_string", &llvm::Init::getAsString) - .def("__str__", &Init::getAsUnquotedString) + .def("__str__", &llvm::Init::getAsUnquotedString) .def("get_as_unquoted_string", &llvm::Init::getAsUnquotedString) .def("dump", &llvm::Init::dump) .def("get_cast_to", &llvm::Init::getCastTo, "ty"_a, @@ -336,7 +286,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::BitInit::getAsString) .def("__str__", &llvm::BitInit::getAsString); - nb::class_(m, "BitsInit") + nb::class_(m, "BitsInit") .def_static("classof", &llvm::BitsInit::classof, "i"_a) .def_static("get", &llvm::BitsInit::get, "rk"_a, "range"_a, nb::rv_policy::reference_internal) @@ -414,7 +364,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { nb::rv_policy::reference_internal); auto llvm_ListInit = - nb::class_(m, "ListInit") + nb::class_(m, "ListInit") .def_static("classof", &llvm::ListInit::classof, "i"_a) .def_static("get", &llvm::ListInit::get, "range"_a, "elt_ty"_a, nb::rv_policy::reference_internal) @@ -439,23 +389,24 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("empty", &llvm::ListInit::empty) .def("get_bit", &llvm::ListInit::getBit, "bit"_a, nb::rv_policy::reference_internal) - .def("__len__", [](const ListInit &v) { return v.size(); }) - .def("__bool__", [](const ListInit &v) { return !v.empty(); }) + .def("__len__", [](const llvm::ListInit &v) { return v.size(); }) + .def("__bool__", [](const llvm::ListInit &v) { return !v.empty(); }) .def( "__iter__", - [](ListInit &v) { + [](llvm::ListInit &v) { return nb::make_iterator( - nb::type(), "Iterator", v.begin(), v.end()); + nb::type(), "Iterator", v.begin(), v.end()); }, nb::rv_policy::reference_internal) .def( "__getitem__", - [](ListInit &v, Py_ssize_t i) { - return v.getElement(nb::detail::wrap(i, v.size())); + [](llvm::ListInit &v, Py_ssize_t i) { + return v.getElement(eudsl::wrap(i, v.size())); }, nb::rv_policy::reference_internal) - .def("get_values", coerceReturn>( - &ListInit::getValues, nb::const_)); + .def("get_values", + eudsl::coerceReturn>( + &llvm::ListInit::getValues, nb::const_)); auto llvm_OpInit = nb::class_(m, "OpInit") .def_static("classof", &llvm::OpInit::classof, "i"_a) @@ -467,20 +418,20 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_bit", &llvm::OpInit::getBit, "bit"_a, nb::rv_policy::reference_internal); - auto unaryOpInit = nb::class_(m, "UnOpInit"); - nb::enum_(m, "UnaryOp") - .value("TOLOWER", UnOpInit::UnaryOp::TOLOWER) - .value("TOUPPER", UnOpInit::UnaryOp::TOUPPER) - .value("CAST", UnOpInit::UnaryOp::CAST) - .value("NOT", UnOpInit::UnaryOp::NOT) - .value("HEAD", UnOpInit::UnaryOp::HEAD) - .value("TAIL", UnOpInit::UnaryOp::TAIL) - .value("SIZE", UnOpInit::UnaryOp::SIZE) - .value("EMPTY", UnOpInit::UnaryOp::EMPTY) - .value("GETDAGOP", UnOpInit::UnaryOp::GETDAGOP) - .value("LOG2", UnOpInit::UnaryOp::LOG2) - .value("REPR", UnOpInit::UnaryOp::REPR) - .value("LISTFLATTEN", UnOpInit::UnaryOp::LISTFLATTEN); + auto unaryOpInit = nb::class_(m, "UnOpInit"); + nb::enum_(m, "UnaryOp") + .value("TOLOWER", llvm::UnOpInit::UnaryOp::TOLOWER) + .value("TOUPPER", llvm::UnOpInit::UnaryOp::TOUPPER) + .value("CAST", llvm::UnOpInit::UnaryOp::CAST) + .value("NOT", llvm::UnOpInit::UnaryOp::NOT) + .value("HEAD", llvm::UnOpInit::UnaryOp::HEAD) + .value("TAIL", llvm::UnOpInit::UnaryOp::TAIL) + .value("SIZE", llvm::UnOpInit::UnaryOp::SIZE) + .value("EMPTY", llvm::UnOpInit::UnaryOp::EMPTY) + .value("GETDAGOP", llvm::UnOpInit::UnaryOp::GETDAGOP) + .value("LOG2", llvm::UnOpInit::UnaryOp::LOG2) + .value("REPR", llvm::UnOpInit::UnaryOp::REPR) + .value("LISTFLATTEN", llvm::UnOpInit::UnaryOp::LISTFLATTEN); unaryOpInit.def_static("classof", &llvm::UnOpInit::classof, "i"_a) .def_static("get", &llvm::UnOpInit::get, "opc"_a, "lhs"_a, "type"_a, @@ -509,36 +460,36 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::UnOpInit::getAsString) .def("__str__", &llvm::UnOpInit::getAsUnquotedString); - auto binaryOpInit = nb::class_(m, "BinOpInit"); - nb::enum_(m, "BinaryOp") - .value("ADD", BinOpInit::BinaryOp::ADD) - .value("SUB", BinOpInit::BinaryOp::SUB) - .value("MUL", BinOpInit::BinaryOp::MUL) - .value("DIV", BinOpInit::BinaryOp::DIV) - .value("AND", BinOpInit::BinaryOp::AND) - .value("OR", BinOpInit::BinaryOp::OR) - .value("XOR", BinOpInit::BinaryOp::XOR) - .value("SHL", BinOpInit::BinaryOp::SHL) - .value("SRA", BinOpInit::BinaryOp::SRA) - .value("SRL", BinOpInit::BinaryOp::SRL) - .value("LISTCONCAT", BinOpInit::BinaryOp::LISTCONCAT) - .value("LISTSPLAT", BinOpInit::BinaryOp::LISTSPLAT) - .value("LISTREMOVE", BinOpInit::BinaryOp::LISTREMOVE) - .value("LISTELEM", BinOpInit::BinaryOp::LISTELEM) - .value("LISTSLICE", BinOpInit::BinaryOp::LISTSLICE) - .value("RANGEC", BinOpInit::BinaryOp::RANGEC) - .value("STRCONCAT", BinOpInit::BinaryOp::STRCONCAT) - .value("INTERLEAVE", BinOpInit::BinaryOp::INTERLEAVE) - .value("CONCAT", BinOpInit::BinaryOp::CONCAT) - .value("EQ", BinOpInit::BinaryOp::EQ) - .value("NE", BinOpInit::BinaryOp::NE) - .value("LE", BinOpInit::BinaryOp::LE) - .value("LT", BinOpInit::BinaryOp::LT) - .value("GE", BinOpInit::BinaryOp::GE) - .value("GT", BinOpInit::BinaryOp::GT) - .value("GETDAGARG", BinOpInit::BinaryOp::GETDAGARG) - .value("GETDAGNAME", BinOpInit::BinaryOp::GETDAGNAME) - .value("SETDAGOP", BinOpInit::BinaryOp::SETDAGOP); + auto binaryOpInit = nb::class_(m, "BinOpInit"); + nb::enum_(m, "BinaryOp") + .value("ADD", llvm::BinOpInit::BinaryOp::ADD) + .value("SUB", llvm::BinOpInit::BinaryOp::SUB) + .value("MUL", llvm::BinOpInit::BinaryOp::MUL) + .value("DIV", llvm::BinOpInit::BinaryOp::DIV) + .value("AND", llvm::BinOpInit::BinaryOp::AND) + .value("OR", llvm::BinOpInit::BinaryOp::OR) + .value("XOR", llvm::BinOpInit::BinaryOp::XOR) + .value("SHL", llvm::BinOpInit::BinaryOp::SHL) + .value("SRA", llvm::BinOpInit::BinaryOp::SRA) + .value("SRL", llvm::BinOpInit::BinaryOp::SRL) + .value("LISTCONCAT", llvm::BinOpInit::BinaryOp::LISTCONCAT) + .value("LISTSPLAT", llvm::BinOpInit::BinaryOp::LISTSPLAT) + .value("LISTREMOVE", llvm::BinOpInit::BinaryOp::LISTREMOVE) + .value("LISTELEM", llvm::BinOpInit::BinaryOp::LISTELEM) + .value("LISTSLICE", llvm::BinOpInit::BinaryOp::LISTSLICE) + .value("RANGEC", llvm::BinOpInit::BinaryOp::RANGEC) + .value("STRCONCAT", llvm::BinOpInit::BinaryOp::STRCONCAT) + .value("INTERLEAVE", llvm::BinOpInit::BinaryOp::INTERLEAVE) + .value("CONCAT", llvm::BinOpInit::BinaryOp::CONCAT) + .value("EQ", llvm::BinOpInit::BinaryOp::EQ) + .value("NE", llvm::BinOpInit::BinaryOp::NE) + .value("LE", llvm::BinOpInit::BinaryOp::LE) + .value("LT", llvm::BinOpInit::BinaryOp::LT) + .value("GE", llvm::BinOpInit::BinaryOp::GE) + .value("GT", llvm::BinOpInit::BinaryOp::GT) + .value("GETDAGARG", llvm::BinOpInit::BinaryOp::GETDAGARG) + .value("GETDAGNAME", llvm::BinOpInit::BinaryOp::GETDAGNAME) + .value("SETDAGOP", llvm::BinOpInit::BinaryOp::SETDAGOP); binaryOpInit.def_static("classof", &llvm::BinOpInit::classof, "i"_a) .def_static("get", &llvm::BinOpInit::get, "opc"_a, "lhs"_a, "rhs"_a, @@ -567,18 +518,19 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::BinOpInit::getAsString) .def("__str__", &llvm::BinOpInit::getAsUnquotedString); - auto ternaryOpInit = nb::class_(m, "TernOpInit"); - nb::enum_(m, "TernaryOp") - .value("SUBST", TernOpInit::TernaryOp::SUBST) - .value("FOREACH", TernOpInit::TernaryOp::FOREACH) - .value("FILTER", TernOpInit::TernaryOp::FILTER) - .value("IF", TernOpInit::TernaryOp::IF) - .value("DAG", TernOpInit::TernaryOp::DAG) - .value("RANGE", TernOpInit::TernaryOp::RANGE) - .value("SUBSTR", TernOpInit::TernaryOp::SUBSTR) - .value("FIND", TernOpInit::TernaryOp::FIND) - .value("SETDAGARG", TernOpInit::TernaryOp::SETDAGARG) - .value("SETDAGNAME", TernOpInit::TernaryOp::SETDAGNAME); + auto ternaryOpInit = + nb::class_(m, "TernOpInit"); + nb::enum_(m, "TernaryOp") + .value("SUBST", llvm::TernOpInit::TernaryOp::SUBST) + .value("FOREACH", llvm::TernOpInit::TernaryOp::FOREACH) + .value("FILTER", llvm::TernOpInit::TernaryOp::FILTER) + .value("IF", llvm::TernOpInit::TernaryOp::IF) + .value("DAG", llvm::TernOpInit::TernaryOp::DAG) + .value("RANGE", llvm::TernOpInit::TernaryOp::RANGE) + .value("SUBSTR", llvm::TernOpInit::TernaryOp::SUBSTR) + .value("FIND", llvm::TernOpInit::TernaryOp::FIND) + .value("SETDAGARG", llvm::TernOpInit::TernaryOp::SETDAGARG) + .value("SETDAGNAME", llvm::TernOpInit::TernaryOp::SETDAGNAME); ternaryOpInit.def_static("classof", &llvm::TernOpInit::classof, "i"_a) .def_static("get", &llvm::TernOpInit::get, "opc"_a, "lhs"_a, "mhs"_a, @@ -604,7 +556,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::TernOpInit::getAsString) .def("__str__", &llvm::TernOpInit::getAsUnquotedString); - nb::class_(m, "CondOpInit") + nb::class_(m, "CondOpInit") .def_static("classof", &llvm::CondOpInit::classof, "i"_a) .def_static("get", &llvm::CondOpInit::get, "c"_a, "v"_a, "type"_a, nb::rv_policy::reference_internal) @@ -641,7 +593,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_bit", &llvm::CondOpInit::getBit, "bit"_a, nb::rv_policy::reference_internal); - nb::class_(m, "FoldOpInit") + nb::class_(m, "FoldOpInit") .def_static("classof", &llvm::FoldOpInit::classof, "i"_a) .def_static("get", &llvm::FoldOpInit::get, "start"_a, "list"_a, "a"_a, "b"_a, "expr"_a, "type"_a, nb::rv_policy::reference_internal) @@ -656,7 +608,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::FoldOpInit::getAsString) .def("__str__", &llvm::FoldOpInit::getAsString); - nb::class_(m, "IsAOpInit") + nb::class_(m, "IsAOpInit") .def_static("classof", &llvm::IsAOpInit::classof, "i"_a) .def_static("get", &llvm::IsAOpInit::get, "check_type"_a, "expr"_a, nb::rv_policy::reference_internal) @@ -670,7 +622,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("__str__", &llvm::IsAOpInit::getAsString) .def("get_as_string", &llvm::IsAOpInit::getAsString); - nb::class_(m, "ExistsOpInit") + nb::class_(m, "ExistsOpInit") .def_static("classof", &llvm::ExistsOpInit::classof, "i"_a) .def_static("get", &llvm::ExistsOpInit::get, "check_type"_a, "expr"_a, nb::rv_policy::reference_internal) @@ -685,7 +637,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::ExistsOpInit::getAsString) .def("__str__", &llvm::ExistsOpInit::getAsUnquotedString); - nb::class_(m, "VarInit") + nb::class_(m, "VarInit") .def_static("classof", &llvm::VarInit::classof, "i"_a) .def_static( "get", @@ -708,7 +660,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::VarInit::getAsString) .def("__str__", &llvm::VarInit::getAsUnquotedString); - nb::class_(m, "VarBitInit") + nb::class_(m, "VarBitInit") .def_static("classof", &llvm::VarBitInit::classof, "i"_a) .def_static("get", &llvm::VarBitInit::get, "t"_a, "b"_a, nb::rv_policy::reference_internal) @@ -722,7 +674,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_bit", &llvm::VarBitInit::getBit, "b"_a, nb::rv_policy::reference_internal); - nb::class_(m, "DefInit") + nb::class_(m, "DefInit") .def_static("classof", &llvm::DefInit::classof, "i"_a) .def("convert_initializer_to", &llvm::DefInit::convertInitializerTo, "ty"_a, nb::rv_policy::reference_internal) @@ -735,7 +687,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_bit", &llvm::DefInit::getBit, "bit"_a, nb::rv_policy::reference_internal); - nb::class_(m, "VarDefInit") + nb::class_(m, "VarDefInit") .def_static("classof", &llvm::VarDefInit::classof, "i"_a) .def_static("get", &llvm::VarDefInit::get, "loc"_a, "class"_a, "args"_a, nb::rv_policy::reference_internal) @@ -756,27 +708,28 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_bit", &llvm::VarDefInit::getBit, "bit"_a, nb::rv_policy::reference_internal) .def("args", - coerceReturn>(&VarDefInit::args, - nb::const_), + eudsl::coerceReturn>( + &llvm::VarDefInit::args, nb::const_), nb::rv_policy::reference_internal) - .def("__len__", [](const VarDefInit &v) { return v.args_size(); }) - .def("__bool__", [](const VarDefInit &v) { return !v.args_empty(); }) + .def("__len__", [](const llvm::VarDefInit &v) { return v.args_size(); }) + .def("__bool__", + [](const llvm::VarDefInit &v) { return !v.args_empty(); }) .def( "__iter__", - [](VarDefInit &v) { + [](llvm::VarDefInit &v) { return nb::make_iterator( - nb::type(), "Iterator", v.args_begin(), + nb::type(), "Iterator", v.args_begin(), v.args_end()); }, nb::rv_policy::reference_internal) .def( "__getitem__", - [](VarDefInit &v, Py_ssize_t i) { - return v.getArg(nb::detail::wrap(i, v.args_size())); + [](llvm::VarDefInit &v, Py_ssize_t i) { + return v.getArg(eudsl::wrap(i, v.args_size())); }, nb::rv_policy::reference_internal); - nb::class_(m, "FieldInit") + nb::class_(m, "FieldInit") .def_static("classof", &llvm::FieldInit::classof, "i"_a) .def_static("get", &llvm::FieldInit::get, "r"_a, "fn"_a, nb::rv_policy::reference_internal) @@ -794,7 +747,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_as_string", &llvm::FieldInit::getAsString) .def("__str__", &llvm::FieldInit::getAsUnquotedString); - nb::class_(m, "DagInit") + nb::class_(m, "DagInit") .def("profile", &llvm::DagInit::Profile, "id"_a) .def("get_operator", &llvm::DagInit::getOperator, nb::rv_policy::reference_internal) @@ -829,26 +782,27 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_bit", &llvm::DagInit::getBit, "bit"_a, nb::rv_policy::reference_internal) .def("get_arg_names", - coerceReturn>(&DagInit::getArgNames, - nb::const_), + eudsl::coerceReturn>( + &llvm::DagInit::getArgNames, nb::const_), nb::rv_policy::reference_internal) .def("get_args", - coerceReturn>(&DagInit::getArgs, - nb::const_), + eudsl::coerceReturn>( + &llvm::DagInit::getArgs, nb::const_), nb::rv_policy::reference_internal) - .def("__len__", [](const DagInit &v) { return v.arg_size(); }) - .def("__bool__", [](const DagInit &v) { return !v.arg_empty(); }) + .def("__len__", [](const llvm::DagInit &v) { return v.arg_size(); }) + .def("__bool__", [](const llvm::DagInit &v) { return !v.arg_empty(); }) .def( "__iter__", - [](DagInit &v) { + [](llvm::DagInit &v) { return nb::make_iterator( - nb::type(), "Iterator", v.arg_begin(), v.arg_end()); + nb::type(), "Iterator", v.arg_begin(), + v.arg_end()); }, nb::rv_policy::reference_internal) .def( "__getitem__", - [](DagInit &v, Py_ssize_t i) { - return v.getArg(nb::detail::wrap(i, v.arg_size())); + [](llvm::DagInit &v, Py_ssize_t i) { + return v.getArg(eudsl::wrap(i, v.arg_size())); }, nb::rv_policy::reference_internal); @@ -892,11 +846,11 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("dump", &llvm::RecordVal::dump) .def("print", &llvm::RecordVal::print, "os"_a, "print_sem"_a) .def("__str__", - [](const RecordVal &self) { + [](const llvm::RecordVal &self) { return self.getValue() ? self.getValue()->getAsUnquotedString() : "<>"; }) - .def("is_used", &RecordVal::isUsed); + .def("is_used", &llvm::RecordVal::isUsed); struct RecordValues {}; nb::class_(m, "RecordValues", nb::dynamic_attr()) @@ -907,7 +861,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { int i = 0; for (auto [key, value] : dic) { s += key + nb::str("=") + - nb::str(nb::cast(value) + nb::str(nb::cast(value) .getValue() ->getAsUnquotedString() .c_str()); @@ -943,16 +897,16 @@ NB_MODULE(eudsl_tblgen_ext, m) { }, nb::rv_policy::reference_internal); - nb::class_(m, "Record") + nb::class_(m, "Record") .def("get_direct_super_classes", - [](const Record &self) -> std::vector { - SmallVector Classes; + [](const llvm::Record &self) -> std::vector { + llvm::SmallVector Classes; self.getDirectSuperClasses(Classes); return {Classes.begin(), Classes.end()}; }) .def( "get_values", - [](Record &self) { + [](llvm::Record &self) { // you can't just call the class_->operator() nb::handle recordValsInstTy = nb::type(); assert(recordValsInstTy.is_valid() && @@ -962,8 +916,8 @@ NB_MODULE(eudsl_tblgen_ext, m) { recordValsInst.type().is(recordValsInstTy) && !nb::inst_ready(recordValsInst)); - std::vector values = self.getValues(); - for (const RecordVal &recordVal : values) { + std::vector values = self.getValues(); + for (const llvm::RecordVal &recordVal : values) { nb::setattr(recordValsInst, recordVal.getName().str().c_str(), nb::borrow(nb::cast(recordVal))); } @@ -971,8 +925,8 @@ NB_MODULE(eudsl_tblgen_ext, m) { }, nb::rv_policy::reference_internal) .def("get_template_args", - coerceReturn>(&Record::getTemplateArgs, - nb::const_), + eudsl::coerceReturn>( + &llvm::Record::getTemplateArgs, nb::const_), nb::rv_policy::reference_internal) .def_static("get_new_uid", &llvm::Record::getNewUID, "rk"_a) .def("get_id", &llvm::Record::getID) @@ -1104,8 +1058,9 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_value_as_dag", &llvm::Record::getValueAsDag, "field_name"_a, nb::rv_policy::reference_internal); - using RecordMap = std::map, std::less<>>; - using GlobalMap = std::map>; + using RecordMap = + std::map, std::less<>>; + using GlobalMap = std::map>; nb::bind_map(m, "GlobalMap"); nb::class_(m, "RecordMap") @@ -1141,26 +1096,27 @@ NB_MODULE(eudsl_tblgen_ext, m) { }, nb::rv_policy::reference_internal); - nb::class_(m, "RecordKeeper") + nb::class_(m, "RecordKeeper") .def(nb::init<>()) .def( "parse_td", - [](RecordKeeper &self, const std::string &inputFilename, + [](llvm::RecordKeeper &self, const std::string &inputFilename, const std::vector &includeDirs, const std::vector ¯oNames, bool noWarnOnUnusedTemplateArgs) { - ErrorOr> fileOrErr = - MemoryBuffer::getFileOrSTDIN(inputFilename, /*IsText=*/true); + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename, + /*IsText=*/true); if (std::error_code EC = fileOrErr.getError()) throw std::runtime_error("Could not open input file '" + inputFilename + "': " + EC.message() + "\n"); self.saveInputFilename(inputFilename); - SourceMgr srcMgr; + llvm::SourceMgr srcMgr; srcMgr.setIncludeDirs(includeDirs); - srcMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc()); - TGParser tgParser(srcMgr, macroNames, self, - noWarnOnUnusedTemplateArgs); + srcMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + llvm::TGParser tgParser(srcMgr, macroNames, self, + noWarnOnUnusedTemplateArgs); if (tgParser.ParseFile()) throw std::runtime_error("Could not parse file '" + inputFilename); @@ -1202,8 +1158,8 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("dump", &llvm::RecordKeeper::dump) .def( "get_all_derived_definitions", - [](RecordKeeper &self, - const std::string &className) -> std::vector { + [](llvm::RecordKeeper &self, const std::string &className) + -> std::vector { return self.getAllDerivedDefinitions(className); }, "class_name"_a, nb::rv_policy::reference_internal);