mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix for captured state
This commit is contained in:
@@ -395,7 +395,16 @@ struct PyCompiledFun {
|
|||||||
nb::object captured_inputs;
|
nb::object captured_inputs;
|
||||||
nb::object captured_outputs;
|
nb::object captured_outputs;
|
||||||
bool shapeless;
|
bool shapeless;
|
||||||
mutable size_t num_outputs{0};
|
|
||||||
|
// Data to attach to the compiled function that contains the python output
|
||||||
|
// structure and the number of arrays in said structure.
|
||||||
|
struct AttachedData {
|
||||||
|
nb::object output_structure;
|
||||||
|
int num_outputs;
|
||||||
|
|
||||||
|
AttachedData(nb::object output_structure_, int num_outputs_)
|
||||||
|
: output_structure(output_structure_), num_outputs(num_outputs_) {}
|
||||||
|
};
|
||||||
|
|
||||||
PyCompiledFun(
|
PyCompiledFun(
|
||||||
const nb::callable& fun,
|
const nb::callable& fun,
|
||||||
@@ -418,7 +427,6 @@ struct PyCompiledFun {
|
|||||||
captured_inputs = std::move(other.captured_inputs);
|
captured_inputs = std::move(other.captured_inputs);
|
||||||
captured_outputs = std::move(other.captured_outputs);
|
captured_outputs = std::move(other.captured_outputs);
|
||||||
shapeless = other.shapeless;
|
shapeless = other.shapeless;
|
||||||
num_outputs = other.num_outputs;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||||
@@ -502,12 +510,9 @@ struct PyCompiledFun {
|
|||||||
auto [outputs, py_outputs] =
|
auto [outputs, py_outputs] =
|
||||||
tree_flatten_with_structure(std::move(tree_outputs), false);
|
tree_flatten_with_structure(std::move(tree_outputs), false);
|
||||||
|
|
||||||
std::shared_ptr<void> py_outputs_void(
|
std::shared_ptr<void> extra_data =
|
||||||
py_outputs.release().ptr(), [](void* handle) {
|
std::make_shared<AttachedData>(py_outputs, outputs.size());
|
||||||
nb::steal(reinterpret_cast<PyObject*>(handle)).reset();
|
|
||||||
});
|
|
||||||
|
|
||||||
num_outputs = outputs.size();
|
|
||||||
if (!captured_outputs.is_none()) {
|
if (!captured_outputs.is_none()) {
|
||||||
auto flat_out_captures = tree_flatten(captured_outputs, false);
|
auto flat_out_captures = tree_flatten(captured_outputs, false);
|
||||||
outputs.insert(
|
outputs.insert(
|
||||||
@@ -520,7 +525,7 @@ struct PyCompiledFun {
|
|||||||
if (!captured_inputs.is_none()) {
|
if (!captured_inputs.is_none()) {
|
||||||
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
||||||
}
|
}
|
||||||
return mx::detail::ArraysAndExtra{outputs, py_outputs_void};
|
return mx::detail::ArraysAndExtra{outputs, extra_data};
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!captured_inputs.is_none()) {
|
if (!captured_inputs.is_none()) {
|
||||||
@@ -532,8 +537,14 @@ struct PyCompiledFun {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compile and call
|
// Compile and call
|
||||||
auto [outputs, py_outputs_void] =
|
auto [outputs, extra_data] =
|
||||||
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||||
|
|
||||||
|
int num_outputs =
|
||||||
|
reinterpret_cast<AttachedData*>(extra_data.get())->num_outputs;
|
||||||
|
nb::object py_outputs =
|
||||||
|
reinterpret_cast<AttachedData*>(extra_data.get())->output_structure;
|
||||||
|
|
||||||
if (!captured_outputs.is_none()) {
|
if (!captured_outputs.is_none()) {
|
||||||
std::vector<mx::array> captures(
|
std::vector<mx::array> captures(
|
||||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||||
@@ -542,9 +553,7 @@ struct PyCompiledFun {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Put the outputs back in the container
|
// Put the outputs back in the container
|
||||||
nb::object py_outputs =
|
return tree_unflatten_from_structure(std::move(py_outputs), outputs);
|
||||||
nb::borrow(reinterpret_cast<PyObject*>(py_outputs_void.get()));
|
|
||||||
return tree_unflatten_from_structure(py_outputs, outputs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
|
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
|
||||||
|
|||||||
@@ -1096,6 +1096,25 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
d = fun(False, mx.array(1.0))
|
d = fun(False, mx.array(1.0))
|
||||||
self.assertTrue(d is None)
|
self.assertTrue(d is None)
|
||||||
|
|
||||||
|
def test_compile_changing_outputs_with_state(self):
|
||||||
|
state = [mx.array(1.0)]
|
||||||
|
|
||||||
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
|
def fun(y):
|
||||||
|
x = state[0]
|
||||||
|
if y.dtype == mx.float32:
|
||||||
|
state[0] = 2 * y
|
||||||
|
return [x, y, x + y]
|
||||||
|
elif y.dtype == mx.int32:
|
||||||
|
state[0] *= 2
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
fun(mx.array(1.0))
|
||||||
|
fun(mx.array(1))
|
||||||
|
|
||||||
|
self.assertEqual(state[0].item(), 4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
Reference in New Issue
Block a user