diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 74909e81d..27d330d6e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5319,4 +5319,44 @@ void init_ops(nb::module_& m) { >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1)) (5, 3, 4) )pbdoc"); + m.def( + "depends", + [](const nb::object& inputs_, const nb::object& deps_) { + bool return_vec = false; + std::vector inputs; + std::vector deps; + if (nb::isinstance(inputs_)) { + inputs = {nb::cast(inputs_)}; + } else { + return_vec = true; + inputs = {nb::cast>(inputs_)}; + } + if (nb::isinstance(deps_)) { + deps = {nb::cast(deps_)}; + } else { + deps = {nb::cast>(deps_)}; + } + auto out = depends(inputs, deps); + if (return_vec) { + return nb::cast(out); + } else { + return nb::cast(out[0]); + } + }, + nb::arg(), + nb::arg(), + nb::sig( + "def depends(inputs: Union[array, Sequence[array]], dependencies: Union[array, Sequence[array]])"), + R"pbdoc( + Insert dependencies between arrays in the graph. The outputs are + identical to ``inputs`` but with dependencies on ``dependencies``. + + Args: + inputs (array or Sequence[array]): The input array or arrays. + dependencies (array or Sequence[array]): The array or arrays + to insert dependencies on. + + Returns: + array or Sequence[array]: The outputs which depend on dependencies. + )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index a915dfe36..e60952aa7 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3081,6 +3081,19 @@ class TestOps(mlx_tests.MLXTestCase): # Doesn't hang x = mx.power(2, -1) + def test_depends(self): + a = mx.array([1.0, 2.0, 3.0]) + b = mx.exp(a) + c = mx.log(a) + out = mx.depends([b], [c])[0] + self.assertTrue(mx.array_equal(out, b)) + + a = mx.array([1.0, 2.0, 3.0]) + b = mx.exp(a) + c = mx.log(a) + out = mx.depends(b, c) + self.assertTrue(mx.array_equal(out, b)) + class TestBroadcast(mlx_tests.MLXTestCase): def test_broadcast_shapes(self):