mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
parent
506d43035c
commit
cb6156d35d
@ -180,7 +180,7 @@ array::ArrayDesc::ArrayDesc(
|
||||
primitive(std::move(primitive)),
|
||||
inputs(inputs) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : inputs) {
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
}
|
||||
@ -197,7 +197,7 @@ array::ArrayDesc::ArrayDesc(
|
||||
primitive(std::move(primitive)),
|
||||
inputs(std::move(inputs)) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : inputs) {
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
}
|
||||
|
@ -63,7 +63,15 @@ std::function<void()> make_task(
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
auto outputs = arr.outputs();
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.push_back(in.data_shared_ptr());
|
||||
|
@ -1779,10 +1779,6 @@ array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
inputs);
|
||||
}
|
||||
array operator&&(const array& a, const array& b) {
|
||||
// check if a and b are bool arrays
|
||||
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
||||
throw std::invalid_argument("[operator&&] only supported for bool arrays.");
|
||||
}
|
||||
return logical_and(a, b);
|
||||
}
|
||||
|
||||
@ -1797,11 +1793,6 @@ array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
inputs);
|
||||
}
|
||||
array operator||(const array& a, const array& b) {
|
||||
// check if a and b are bool arrays
|
||||
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"[operator||] is only supported for bool arrays.");
|
||||
}
|
||||
return logical_or(a, b);
|
||||
}
|
||||
|
||||
|
@ -393,6 +393,28 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
g = mx.grad(lambda x: x**2)(x)
|
||||
self.assertAlmostEqual(g.item(), 4.0)
|
||||
|
||||
def test_eval_in_grad(self):
|
||||
arr = mx.array([1.0])
|
||||
cotan = mx.array([1.0, 1.0])
|
||||
y = mx.array([2.0, 2.0])
|
||||
|
||||
def func(x):
|
||||
x = x + y
|
||||
cond = x < 1
|
||||
cond.tolist()
|
||||
return x**2
|
||||
|
||||
_, vjps = mx.vjp(func, (arr,), (cotan,))
|
||||
self.assertEqual(vjps[0].item(), 12.0)
|
||||
|
||||
def func(x):
|
||||
x = x + mx.array([1.0, 1.0])
|
||||
mx.eval(x)
|
||||
return x**2
|
||||
|
||||
_, vjps = mx.vjp(func, (arr,), (cotan,))
|
||||
self.assertEqual(vjps[0].item(), 8.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user