Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jan 1, 2025
1 parent 483d843 commit 5e2252f
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 91 deletions.
1 change: 1 addition & 0 deletions projects/eudsl-py/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ set(nanobind_options
-Wno-nested-anon-types
-Wno-zero-length-array
-Wno-c++98-compat-extra-semi
-Wno-c++20-extensions
$<$<PLATFORM_ID:Linux>:-fexceptions -frtti>
$<$<PLATFORM_ID:Darwin>:-fexceptions -frtti>
$<$<PLATFORM_ID:Windows>:/EHsc /GR>
Expand Down
109 changes: 51 additions & 58 deletions projects/eudsl-py/src/bind_vec_like.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,48 @@ std::tuple<nanobind::class_<llvm::SmallVector<Element>>,
nanobind::class_<llvm::MutableArrayRef<Element>>>
bind_array_ref(nanobind::handle scope, Args &&...args) {
using ArrayRef = llvm::ArrayRef<Element>;
using SmallVec = llvm::SmallVector<Element>;
using MutableArrayRef = llvm::MutableArrayRef<Element>;
using ValueRef = Element &;

auto vecClName = "SmallVector[" + std::string(typeid(Element).name()) + "]";
nanobind::handle array_cl_cur = nanobind::type<ArrayRef>();
if (array_cl_cur.is_valid()) {
nanobind::handle smallvec_cl_cur = nanobind::type<SmallVec>();
assert(smallvec_cl_cur.is_valid() &&
"expected SmallVec to already have been registered");
nanobind::handle mutable_cl_cur = nanobind::type<MutableArrayRef>();
assert(mutable_cl_cur.is_valid() &&
"expected MutableArrayRef to already have been registered");
return std::make_tuple(
nanobind::borrow<nanobind::class_<SmallVec>>(array_cl_cur),
nanobind::borrow<nanobind::class_<ArrayRef>>(array_cl_cur),
nanobind::borrow<nanobind::class_<MutableArrayRef>>(array_cl_cur));
}

std::string typename_;
if (nanobind::type<Element>().ptr()) {
typename_ =
std::string(nanobind::type_name(nanobind::type<Element>()).c_str());

} else
typename_ = '"' + llvm::getTypeName<Element>().str() + '"';

std::string vecClName = "SmallVector[" + typename_ + "]";
auto _smallVectorOfElement =
nanobind::bind_vector<llvm::SmallVector<Element>>(scope,
vecClName.c_str());
nanobind::bind_vector<SmallVec>(scope, vecClName.c_str());

smallVector.def_static(
"__class_getitem__",
[_smallVectorOfElement](nanobind::type_object_t<Element>) {
return _smallVectorOfElement;
});

auto arrClName = "ArrayRef[" + std::string(typeid(Element).name()) + "]";
std::string arrClName = "ArrayRef[" + typename_ + "]";
auto cl =
nanobind::class_<ArrayRef>(scope, arrClName.c_str(),
std::forward<Args>(args)...)
.def(nanobind::init<const llvm::SmallVector<Element> &>())
.def(nanobind::init_implicit<llvm::SmallVector<Element>>())
.def(nanobind::init<const SmallVec &>())
.def(nanobind::init_implicit<SmallVec>())
.def("__len__", [](const ArrayRef &v) { return v.size(); })
.def("__bool__", [](const ArrayRef &v) { return !v.empty(); })
.def("__repr__",
Expand All @@ -71,39 +94,29 @@ bind_array_ref(nanobind::handle scope, Args &&...args) {

arrayRef.def_static("__class_getitem__",
[cl](nanobind::type_object_t<Element>) { return cl; });
arrayRef.def(nanobind::new_([](const llvm::SmallVector<Element> &sv) {
return llvm::ArrayRef<Element>(sv);
}));
arrayRef.def(nanobind::new_([](const SmallVec &sv) { return ArrayRef(sv); }));

if constexpr (nanobind::detail::is_equality_comparable_v<Element>) {
cl.def(nanobind::self == nanobind::self,
nanobind::sig("def __eq__(self, arg: object, /) -> bool"))
.def(nanobind::self != nanobind::self,
nanobind::sig("def __ne__(self, arg: object, /) -> bool"))

.def("__contains__",
[](const ArrayRef &v, const Element &x) {
return std::find(v.begin(), v.end(), x) != v.end();
})

.def("__contains__", // fallback for incompatible types
.def("__contains__",
[](const ArrayRef &, nanobind::handle) { return false; })

.def(
"count",
[](const ArrayRef &v, const Element &x) {
return std::count(v.begin(), v.end(), x);
},
"Return number of occurrences of `arg`.");
.def("count", [](const ArrayRef &v, const Element &x) {
return std::count(v.begin(), v.end(), x);
});
}

using MutableArrayRef = llvm::MutableArrayRef<Element>;
auto mutableArrClName =
"MutableArrayRef[" + std::string(typeid(Element).name()) + "]";
std::string mutableArrClName = "MutableArrayRef[" + typename_ + "]";
auto mutableCl =
nanobind::class_<MutableArrayRef>(scope, arrClName.c_str(),
std::forward<Args>(args)...)
.def(nanobind::init<llvm::SmallVector<Element> &>())
.def(nanobind::init<SmallVec &>())
.def("__len__", [](const MutableArrayRef &v) { return v.size(); })
.def("__bool__", [](const MutableArrayRef &v) { return !v.empty(); })
.def("__repr__",
Expand Down Expand Up @@ -136,21 +149,15 @@ bind_array_ref(nanobind::handle scope, Args &&...args) {
nanobind::sig("def __eq__(self, arg: object, /) -> bool"))
.def(nanobind::self != nanobind::self,
nanobind::sig("def __ne__(self, arg: object, /) -> bool"))

.def("__contains__",
[](const MutableArrayRef &v, const Element &x) {
return std::find(v.begin(), v.end(), x) != v.end();
})

.def("__contains__", // fallback for incompatible types
.def("__contains__",
[](const MutableArrayRef &, nanobind::handle) { return false; })

.def(
"count",
[](const MutableArrayRef &v, const Element &x) {
return std::count(v.begin(), v.end(), x);
},
"Return number of occurrences of `arg`.");
.def("count", [](const MutableArrayRef &v, const Element &x) {
return std::count(v.begin(), v.end(), x);
});
}

return {_smallVectorOfElement, cl, mutableCl};
Expand All @@ -162,17 +169,13 @@ template <typename Vector,
nanobind::class_<Vector> bind_iter_like(nanobind::handle scope,
const char *name, Args &&...args) {
nanobind::handle cl_cur = nanobind::type<Vector>();
if (cl_cur.is_valid()) {
// Binding already exists, don't re-create
if (cl_cur.is_valid())
return nanobind::borrow<nanobind::class_<Vector>>(cl_cur);
}

auto cl =
nanobind::class_<Vector>(scope, name, std::forward<Args>(args)...)
.def("__len__", [](const Vector &v) -> int { return v.size(); })
.def(
"__bool__", [](const Vector &v) { return !v.empty(); },
"Check whether the vector is nonempty")
.def("__bool__", [](const Vector &v) { return !v.empty(); })
.def("__repr__",
[](nanobind::handle_t<Vector> h) {
return nanobind::steal<nanobind::str>(
Expand Down Expand Up @@ -212,14 +215,11 @@ nanobind::class_<Vector> bind_iter_like(nanobind::handle scope,
[](const Vector &v, const Value &x) {
return std::find(v.begin(), v.end(), x) != v.end();
})
.def("__contains__", // fallback for incompatible types
.def("__contains__",
[](const Vector &, nanobind::handle) { return false; })
.def(
"count",
[](const Vector &v, const Value &x) {
return std::count(v.begin(), v.end(), x);
},
"Return number of occurrences of `arg`.");
.def("count", [](const Vector &v, const Value &x) {
return std::count(v.begin(), v.end(), x);
});
}

return cl;
Expand All @@ -231,17 +231,13 @@ template <typename Vector, typename ValueRef,
nanobind::class_<Vector> bind_iter_range(nanobind::handle scope,
const char *name, Args &&...args) {
nanobind::handle cl_cur = nanobind::type<Vector>();
if (cl_cur.is_valid()) {
// Binding already exists, don't re-create
if (cl_cur.is_valid())
return nanobind::borrow<nanobind::class_<Vector>>(cl_cur);
}

auto cl =
nanobind::class_<Vector>(scope, name, std::forward<Args>(args)...)
.def("__len__", [](const Vector &v) -> int { return v.size(); })
.def(
"__bool__", [](const Vector &v) { return !v.empty(); },
"Check whether the vector is nonempty")
.def("__bool__", [](const Vector &v) { return !v.empty(); })
.def("__repr__",
[](nanobind::handle_t<Vector> h) {
return nanobind::steal<nanobind::str>(
Expand Down Expand Up @@ -280,12 +276,9 @@ nanobind::class_<Vector> bind_iter_range(nanobind::handle scope,
})
.def("__contains__", // fallback for incompatible types
[](const Vector &, nanobind::handle) { return false; })
.def(
"count",
[](const Vector &v, const Value &x) {
return std::count(v.begin(), v.end(), x);
},
"Return number of occurrences of `arg`.");
.def("count", [](const Vector &v, const Value &x) {
return std::count(v.begin(), v.end(), x);
});
}

return cl;
Expand Down
26 changes: 7 additions & 19 deletions projects/eudsl-py/src/eudslpy-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,8 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci,
if (decl->isTemplated()) {
clang::DiagnosticBuilder builder = ci.getDiagnostics().Report(
decl->getLocation(), ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
clang::DiagnosticsEngine::Note,
"template classes not supported yet"));
// have to force emit because after the fatal error, no more warnings will
// be emitted
// https://github.com/llvm/llvm-project/blob/d74214cc8c03159e5d1f1168a09368cf3b23fd5f/clang/lib/Basic/DiagnosticIDs.cpp#L796
(void)builder.setForceEmit();
return false;
}

Expand All @@ -320,9 +316,8 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci,
if (decl->getNumBases() > 1) {
clang::DiagnosticBuilder builder = ci.getDiagnostics().Report(
decl->getLocation(), ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
clang::DiagnosticsEngine::Note,
"multiple base classes not supported"));
(void)builder.setForceEmit();
} else if (decl->getNumBases() == 1) {
// handle some known bases that we've already found a wap to bind
clang::CXXBaseSpecifier baseClass = *decl->bases_begin();
Expand Down Expand Up @@ -351,10 +346,9 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci,
"expected class template specialization");
clang::DiagnosticBuilder builder = ci.getDiagnostics().Report(
baseClass.getBeginLoc(), ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
clang::DiagnosticsEngine::Note,
"unknown base templated base class: "));
builder << baseName << "\n";
(void)builder.setForceEmit();
}
}

Expand Down Expand Up @@ -494,9 +488,8 @@ struct BindingsVisitor
if (decl->isAnonymousStructOrUnion()) {
clang::DiagnosticBuilder builder = ci.getDiagnostics().Report(
decl->getLocation(), ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
clang::DiagnosticsEngine::Note,
"anon structs/union fields not supported"));
(void)builder.setForceEmit();
return true;
}
if (decl->isBitField())
Expand Down Expand Up @@ -533,17 +526,15 @@ struct BindingsVisitor
decl->isFunctionTemplateSpecialization()) {
clang::DiagnosticBuilder builder = ci.getDiagnostics().Report(
decl->getLocation(), ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
clang::DiagnosticsEngine::Note,
"template methods not supported yet"));
(void)builder.setForceEmit();
return true;
}
if (decl->getFriendObjectKind()) {
clang::DiagnosticBuilder builder = ci.getDiagnostics().Report(
decl->getLocation(), ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
clang::DiagnosticsEngine::Note,
"friend functions not supported"));
(void)builder.setForceEmit();
return true;
}
emitClassMethodOrFunction(decl, ci, outputFile);
Expand All @@ -565,9 +556,8 @@ struct BindingsVisitor
decl->isFunctionTemplateSpecialization()) {
clang::DiagnosticBuilder builder = ci.getDiagnostics().Report(
decl->getLocation(), ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
clang::DiagnosticsEngine::Note,
"template functions not supported yet"));
(void)builder.setForceEmit();
return true;
}
emitClassMethodOrFunction(decl, ci, outputFile);
Expand Down Expand Up @@ -646,9 +636,7 @@ struct GenerateBindingsAction : clang::ASTFrontendAction {
std::unique_ptr<clang::ASTConsumer>
CreateASTConsumer(clang::CompilerInstance &compiler,
llvm::StringRef inFile) override {
// compiler.getPreprocessor().SetSuppressIncludeNotFoundError(true);
compiler.getDiagnosticOpts().ShowLevel = true;
compiler.getDiagnosticOpts().IgnoreWarnings = false;
return std::make_unique<ClassStructEnumConsumer>(compiler, outputFile);
}

Expand Down
17 changes: 7 additions & 10 deletions projects/eudsl-py/src/eudslpy_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,9 @@ NB_MODULE(eudslpy_ext, m) {
nb::class_<mlir::TypeID>(m, "TypeID");
nb::class_<mlir::detail::InterfaceMap>(m, "InterfaceMap");

auto irModule = m.def_submodule("ir");
populateIRModule(irModule);

nb::class_<llvm::FailureOr<bool>>(m, "FailureOr[bool]");
nb::class_<llvm::FailureOr<mlir::StringAttr>>(m, "FailureOr[StringAttr]");
nb::class_<llvm::FailureOr<mlir::AsmResourceBlob>>(
Expand Down Expand Up @@ -436,8 +439,6 @@ NB_MODULE(eudslpy_ext, m) {
bind_array_ref<char>(m);
auto [smallVectorOfDouble, arrayRefOfDouble, mutableArrayRefOfDouble] =
bind_array_ref<double>(m);
auto [smallVectorOfLong, arrayRefOfLong, mutableArrayRefOfLong] =
bind_array_ref<long>(m);

auto [smallVectorOfInt16, arrayRefOfInt16, mutableArrayRefOfInt16] =
bind_array_ref<int16_t>(m);
Expand Down Expand Up @@ -490,8 +491,8 @@ NB_MODULE(eudslpy_ext, m) {
[smallVectorOfBool, smallVectorOfInt, smallVectorOfFloat,
smallVectorOfInt16, smallVectorOfInt32, smallVectorOfInt64,
smallVectorOfUInt16, smallVectorOfUInt32, smallVectorOfUInt64,
smallVectorOfChar, smallVectorOfDouble,
smallVectorOfLong](nb::type_object type) -> nb::object {
smallVectorOfChar,
smallVectorOfDouble](nb::type_object type) -> nb::object {
PyTypeObject *typeObj = (PyTypeObject *)type.ptr();
nb::print(type);
if (typeObj == &PyBool_Type)
Expand Down Expand Up @@ -538,16 +539,14 @@ NB_MODULE(eudslpy_ext, m) {
"__class_getitem__",
[smallVectorOfFloat, smallVectorOfInt16, smallVectorOfInt32,
smallVectorOfInt64, smallVectorOfUInt16, smallVectorOfUInt32,
smallVectorOfUInt64, smallVectorOfChar, smallVectorOfDouble,
smallVectorOfLong](std::string type) -> nb::object {
smallVectorOfUInt64, smallVectorOfChar,
smallVectorOfDouble](std::string type) -> nb::object {
if (type == "char")
return smallVectorOfChar;
if (type == "float")
return smallVectorOfFloat;
if (type == "double")
return smallVectorOfDouble;
if (type == "long")
return smallVectorOfLong;
if (type == "int16")
return smallVectorOfInt16;
if (type == "int32")
Expand Down Expand Up @@ -589,8 +588,6 @@ NB_MODULE(eudslpy_ext, m) {
bind_iter_like<llvm::iplist<mlir::Operation>,
nb::rv_policy::reference_internal>(m, "iplist[Operation]");

auto irModule = m.def_submodule("ir");
populateIRModule(irModule);
auto dialectsModule = m.def_submodule("dialects");

// auto accModule = dialectsModule.def_submodule("acc");
Expand Down
5 changes: 1 addition & 4 deletions projects/eudsl-py/src/type_casters.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Copyright (c) 2024.

#ifndef TYPE_CASTERS_H
#define TYPE_CASTERS_H
#pragma once

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
Expand Down Expand Up @@ -87,5 +86,3 @@ struct nanobind::detail::type_caster<llvm::Twine> {
return PyUnicode_FromStringAndSize(s.data(), s.size());
}
};

#endif // TYPE_CASTERS_H

0 comments on commit 5e2252f

Please sign in to comment.