mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-04 02:28:13 +08:00
basic python tests
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
@@ -163,6 +162,19 @@ py::object tree_unflatten(
|
||||
});
|
||||
}
|
||||
|
||||
py::object tree_unflatten_none(
|
||||
py::object tree,
|
||||
const std::vector<array>& values,
|
||||
int index = 0) {
|
||||
return tree_map(tree, [&](py::handle obj) {
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
return py::cast(values[index++]);
|
||||
} else {
|
||||
return py::cast<py::object>(obj);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
auto validate_argnums_argnames(
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
const StrOrVec& argnames) {
|
||||
@@ -438,30 +450,36 @@ auto py_vmap(
|
||||
}
|
||||
|
||||
auto py_compile(const py::function& fun) {
|
||||
// This map is used to Cache the tree structure of the outputs
|
||||
static std::unordered_map<size_t, py::object> tree_cache;
|
||||
|
||||
return [fun](const py::args& args) {
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
|
||||
// py_value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
py::object py_outputs;
|
||||
|
||||
auto compile_fun =
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py_outputs = fun(*tree_unflatten(args, a));
|
||||
|
||||
// Flatten the outputs
|
||||
return tree_flatten(py_outputs, true);
|
||||
};
|
||||
|
||||
// Compile and call
|
||||
// TODO, awni, I think this cast is ok??
|
||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||
|
||||
auto compile_fun = [fun_id, &fun, &args, &inputs](
|
||||
const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py::object py_outputs = fun(*tree_unflatten(args, a));
|
||||
|
||||
// Flatten the outputs
|
||||
auto outputs = tree_flatten(py_outputs, true);
|
||||
|
||||
py_outputs =
|
||||
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||
tree_cache.insert({fun_id, py_outputs});
|
||||
return outputs;
|
||||
};
|
||||
|
||||
// Compile and call
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
|
||||
// Put the outputs back in the container
|
||||
return tree_unflatten(py_outputs, outputs);
|
||||
py::object py_outputs = tree_cache.at(fun_id);
|
||||
return tree_unflatten_none(py_outputs, outputs);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,12 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
compiled_fn = mx.compile(fun)
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
# out = compiled_fn(x, y)
|
||||
out = compiled_fn(x, y)
|
||||
self.assertEqual(out.item(), 2.0)
|
||||
|
||||
# Try again
|
||||
out = compiled_fn(x, y)
|
||||
self.assertEqual(out.item(), 2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user