2025-02-12 06:45:02 +08:00
|
|
|
// Copyright © 2025 Apple Inc.
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
#include <nanobind/nanobind.h>
|
|
|
|
#include <nanobind/stl/function.h>
|
|
|
|
|
|
|
|
namespace nb = nanobind;
|
|
|
|
using namespace nb::literals;
|
|
|
|
|
2025-08-05 07:14:18 +08:00
|
|
|
// Non-template entry point; its definition is not part of this header.
// The template overloads in this header reduce their variadic dependency
// handles to a std::vector<PyObject*> and then call this function.
//
//   func       Python object that is wrapped and returned as a callable.
//   orig_func  Original Python callable associated with `func`.
//              NOTE(review): presumably used to expose the original's
//              metadata (name/doc) — confirm against the implementation.
//   deps       Raw PyObject pointers for objects `func` depends on. The
//              template overloads obtain them via `.ptr()` without taking
//              new references. NOTE(review): presumably so the wrapper can
//              report them to the Python GC — TODO confirm.
nb::callable mlx_func(nb::object func,
                      const nb::callable& orig_func,
                      std::vector<PyObject*> deps);
|
2025-02-12 06:45:02 +08:00
|
|
|
|
|
|
|
template <typename F, typename... Deps>
|
2025-08-05 07:14:18 +08:00
|
|
|
nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) {
|
2025-02-12 06:45:02 +08:00
|
|
|
return mlx_func(
|
2025-08-05 07:14:18 +08:00
|
|
|
nb::cpp_function(std::move(func)),
|
|
|
|
orig_func,
|
|
|
|
std::vector<PyObject*>{deps.ptr()...});
|
2025-02-12 06:45:02 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
template <typename... Deps>
|
2025-08-05 07:14:18 +08:00
|
|
|
nb::callable
|
|
|
|
mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) {
|
|
|
|
return mlx_func(
|
|
|
|
std::move(func), orig_func, std::vector<PyObject*>{deps.ptr()...});
|
2025-02-12 06:45:02 +08:00
|
|
|
}
|