mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 01:51:18 +08:00
fix test
This commit is contained in:
parent
4f50935c2c
commit
6189111494
@ -477,13 +477,16 @@ auto py_compile(const py::function& fun) {
|
||||
auto inputs = tree_flatten(args, true);
|
||||
|
||||
// Get globally enclosed arrays so we don't compile through them
|
||||
if (py::hasattr(fun, "__globals__")) {
|
||||
auto global_inputs = tree_flatten(py::getattr(fun, "__globals__"), false);
|
||||
std::move(
|
||||
std::begin(global_inputs),
|
||||
std::end(global_inputs),
|
||||
std::back_inserter(inputs));
|
||||
}
|
||||
|
||||
// Get locally enclosed arrays so we don't compile through them
|
||||
if (py::hasattr(fun, "__closure__")) {
|
||||
auto closures = py::getattr(fun, "__closure__");
|
||||
if (py::isinstance<py::tuple>(closures)) {
|
||||
for (auto& closure : closures) {
|
||||
@ -495,6 +498,7 @@ auto py_compile(const py::function& fun) {
|
||||
std::back_inserter(inputs));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compile and call
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
|
Loading…
Reference in New Issue
Block a user