mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
506d43035c
commit
cb6156d35d
@ -180,7 +180,7 @@ array::ArrayDesc::ArrayDesc(
|
|||||||
primitive(std::move(primitive)),
|
primitive(std::move(primitive)),
|
||||||
inputs(inputs) {
|
inputs(inputs) {
|
||||||
std::tie(size, strides) = cum_prod(this->shape);
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
for (auto& in : inputs) {
|
for (auto& in : this->inputs) {
|
||||||
is_tracer |= in.is_tracer();
|
is_tracer |= in.is_tracer();
|
||||||
depth = std::max(in.graph_depth(), depth);
|
depth = std::max(in.graph_depth(), depth);
|
||||||
}
|
}
|
||||||
@ -197,7 +197,7 @@ array::ArrayDesc::ArrayDesc(
|
|||||||
primitive(std::move(primitive)),
|
primitive(std::move(primitive)),
|
||||||
inputs(std::move(inputs)) {
|
inputs(std::move(inputs)) {
|
||||||
std::tie(size, strides) = cum_prod(this->shape);
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
for (auto& in : inputs) {
|
for (auto& in : this->inputs) {
|
||||||
is_tracer |= in.is_tracer();
|
is_tracer |= in.is_tracer();
|
||||||
depth = std::max(in.graph_depth(), depth);
|
depth = std::max(in.graph_depth(), depth);
|
||||||
}
|
}
|
||||||
|
@ -63,7 +63,15 @@ std::function<void()> make_task(
|
|||||||
auto s = arr.primitive().stream();
|
auto s = arr.primitive().stream();
|
||||||
auto command_buffer = increment_command_buffer(s);
|
auto command_buffer = increment_command_buffer(s);
|
||||||
auto outputs = arr.outputs();
|
auto outputs = arr.outputs();
|
||||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
{
|
||||||
|
// If the array is a tracer hold a reference
|
||||||
|
// to its inputs so they don't get donated
|
||||||
|
std::vector<array> inputs;
|
||||||
|
if (arr.is_tracer()) {
|
||||||
|
inputs = arr.inputs();
|
||||||
|
}
|
||||||
|
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||||
|
}
|
||||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
buffers.push_back(in.data_shared_ptr());
|
buffers.push_back(in.data_shared_ptr());
|
||||||
|
@ -1779,10 +1779,6 @@ array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|||||||
inputs);
|
inputs);
|
||||||
}
|
}
|
||||||
array operator&&(const array& a, const array& b) {
|
array operator&&(const array& a, const array& b) {
|
||||||
// check if a and b are bool arrays
|
|
||||||
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
|
||||||
throw std::invalid_argument("[operator&&] only supported for bool arrays.");
|
|
||||||
}
|
|
||||||
return logical_and(a, b);
|
return logical_and(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1797,11 +1793,6 @@ array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|||||||
inputs);
|
inputs);
|
||||||
}
|
}
|
||||||
array operator||(const array& a, const array& b) {
|
array operator||(const array& a, const array& b) {
|
||||||
// check if a and b are bool arrays
|
|
||||||
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[operator||] is only supported for bool arrays.");
|
|
||||||
}
|
|
||||||
return logical_or(a, b);
|
return logical_or(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -393,6 +393,28 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
g = mx.grad(lambda x: x**2)(x)
|
g = mx.grad(lambda x: x**2)(x)
|
||||||
self.assertAlmostEqual(g.item(), 4.0)
|
self.assertAlmostEqual(g.item(), 4.0)
|
||||||
|
|
||||||
|
def test_eval_in_grad(self):
|
||||||
|
arr = mx.array([1.0])
|
||||||
|
cotan = mx.array([1.0, 1.0])
|
||||||
|
y = mx.array([2.0, 2.0])
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
x = x + y
|
||||||
|
cond = x < 1
|
||||||
|
cond.tolist()
|
||||||
|
return x**2
|
||||||
|
|
||||||
|
_, vjps = mx.vjp(func, (arr,), (cotan,))
|
||||||
|
self.assertEqual(vjps[0].item(), 12.0)
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
x = x + mx.array([1.0, 1.0])
|
||||||
|
mx.eval(x)
|
||||||
|
return x**2
|
||||||
|
|
||||||
|
_, vjps = mx.vjp(func, (arr,), (cotan,))
|
||||||
|
self.assertEqual(vjps[0].item(), 8.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user