fix wraps compile (#2461)

This commit is contained in:
Awni Hannun
2025-08-04 16:14:18 -07:00
committed by GitHub
parent 6ad0889c8a
commit 0b807893a7
4 changed files with 48 additions and 38 deletions

View File

@@ -10,15 +10,22 @@
namespace nb = nanobind;
using namespace nb::literals;
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps);
nb::callable mlx_func(
nb::object func,
const nb::callable& orig_func,
std::vector<PyObject*> deps);
template <typename F, typename... Deps>
nb::callable mlx_func(F func, Deps&&... deps) {
nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) {
return mlx_func(
nb::cpp_function(std::move(func)), std::vector<PyObject*>{deps.ptr()...});
nb::cpp_function(std::move(func)),
orig_func,
std::vector<PyObject*>{deps.ptr()...});
}
template <typename... Deps>
nb::callable mlx_func(nb::object func, Deps&&... deps) {
return mlx_func(std::move(func), std::vector<PyObject*>{deps.ptr()...});
nb::callable
mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) {
return mlx_func(
std::move(func), orig_func, std::vector<PyObject*>{deps.ptr()...});
}