Fix eval in trace bugs (#612)

* Fix eval in trace bugs

* comment nit
This commit is contained in:
Awni Hannun 2024-02-02 09:57:12 -08:00 committed by GitHub
parent 506d43035c
commit cb6156d35d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 33 additions and 12 deletions

View File

@ -180,7 +180,7 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
@ -197,7 +197,7 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}

View File

@ -63,7 +63,15 @@ std::function<void()> make_task(
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
auto outputs = arr.outputs();
arr.primitive().eval_gpu(arr.inputs(), outputs);
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.push_back(in.data_shared_ptr());

View File

@ -1779,10 +1779,6 @@ array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
inputs);
}
array operator&&(const array& a, const array& b) {
// check if a and b are bool arrays
if (a.dtype() != bool_ || b.dtype() != bool_) {
throw std::invalid_argument("[operator&&] only supported for bool arrays.");
}
return logical_and(a, b);
}
@ -1797,11 +1793,6 @@ array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
inputs);
}
array operator||(const array& a, const array& b) {
// check if a and b are bool arrays
if (a.dtype() != bool_ || b.dtype() != bool_) {
throw std::invalid_argument(
"[operator||] is only supported for bool arrays.");
}
return logical_or(a, b);
}

View File

@ -393,6 +393,28 @@ class TestAutograd(mlx_tests.MLXTestCase):
g = mx.grad(lambda x: x**2)(x)
self.assertAlmostEqual(g.item(), 4.0)
def test_eval_in_grad(self):
arr = mx.array([1.0])
cotan = mx.array([1.0, 1.0])
y = mx.array([2.0, 2.0])
def func(x):
x = x + y
cond = x < 1
cond.tolist()
return x**2
_, vjps = mx.vjp(func, (arr,), (cotan,))
self.assertEqual(vjps[0].item(), 12.0)
def func(x):
x = x + mx.array([1.0, 1.0])
mx.eval(x)
return x**2
_, vjps = mx.vjp(func, (arr,), (cotan,))
self.assertEqual(vjps[0].item(), 8.0)
if __name__ == "__main__":
unittest.main()