From cb6156d35d176aa4e27d1ba26edd6e71067a2fc3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 2 Feb 2024 09:57:12 -0800 Subject: [PATCH] Fix eval in trace bugs (#612) * Fix eval in trace bugs * comment nit --- mlx/array.cpp | 4 ++-- mlx/backend/metal/metal.cpp | 10 +++++++++- mlx/ops.cpp | 9 --------- python/tests/test_autograd.py | 22 ++++++++++++++++++++++ 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index 59902c86d..7f3dd854b 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -180,7 +180,7 @@ array::ArrayDesc::ArrayDesc( primitive(std::move(primitive)), inputs(inputs) { std::tie(size, strides) = cum_prod(this->shape); - for (auto& in : inputs) { + for (auto& in : this->inputs) { is_tracer |= in.is_tracer(); depth = std::max(in.graph_depth(), depth); } @@ -197,7 +197,7 @@ array::ArrayDesc::ArrayDesc( primitive(std::move(primitive)), inputs(std::move(inputs)) { std::tie(size, strides) = cum_prod(this->shape); - for (auto& in : inputs) { + for (auto& in : this->inputs) { is_tracer |= in.is_tracer(); depth = std::max(in.graph_depth(), depth); } diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 372b3d231..96cf87818 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -63,7 +63,15 @@ std::function make_task( auto s = arr.primitive().stream(); auto command_buffer = increment_command_buffer(s); auto outputs = arr.outputs(); - arr.primitive().eval_gpu(arr.inputs(), outputs); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + arr.primitive().eval_gpu(arr.inputs(), outputs); + } std::vector> buffers; for (auto& in : arr.inputs()) { buffers.push_back(in.data_shared_ptr()); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index d66790a59..cb15a9570 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1779,10 +1779,6 @@ array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) { inputs); } array operator&&(const array& a, const array& b) { - // check if a and b are bool arrays - if (a.dtype() != bool_ || b.dtype() != bool_) { - throw std::invalid_argument("[operator&&] only supported for bool arrays."); - } return logical_and(a, b); } @@ -1797,11 +1793,6 @@ array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) { inputs); } array operator||(const array& a, const array& b) { - // check if a and b are bool arrays - if (a.dtype() != bool_ || b.dtype() != bool_) { - throw std::invalid_argument( - "[operator||] is only supported for bool arrays."); - } return logical_or(a, b); } diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 78f7346a8..6c5c922b1 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -393,6 +393,28 @@ class TestAutograd(mlx_tests.MLXTestCase): g = mx.grad(lambda x: x**2)(x) self.assertAlmostEqual(g.item(), 4.0) + def test_eval_in_grad(self): + arr = mx.array([1.0]) + cotan = mx.array([1.0, 1.0]) + y = mx.array([2.0, 2.0]) + + def func(x): + x = x + y + cond = x < 1 + cond.tolist() + return x**2 + + _, vjps = mx.vjp(func, (arr,), (cotan,)) + self.assertEqual(vjps[0].item(), 12.0) + + def func(x): + x = x + mx.array([1.0, 1.0]) + mx.eval(x) + return x**2 + + _, vjps = mx.vjp(func, (arr,), (cotan,)) + self.assertEqual(vjps[0].item(), 8.0) + if __name__ == "__main__": unittest.main()