expose depends (#2606)

This commit is contained in:
Awni Hannun
2025-09-18 10:06:15 -07:00
committed by GitHub
parent 3f730e77aa
commit 50cc09887f
2 changed files with 53 additions and 0 deletions

View File

@@ -5319,4 +5319,44 @@ void init_ops(nb::module_& m) {
>>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1))
(5, 3, 4)
)pbdoc");
m.def(
"depends",
[](const nb::object& inputs_, const nb::object& deps_) {
bool return_vec = false;
std::vector<mx::array> inputs;
std::vector<mx::array> deps;
if (nb::isinstance<mx::array>(inputs_)) {
inputs = {nb::cast<mx::array>(inputs_)};
} else {
return_vec = true;
inputs = {nb::cast<std::vector<mx::array>>(inputs_)};
}
if (nb::isinstance<mx::array>(deps_)) {
deps = {nb::cast<mx::array>(deps_)};
} else {
deps = {nb::cast<std::vector<mx::array>>(deps_)};
}
auto out = depends(inputs, deps);
if (return_vec) {
return nb::cast(out);
} else {
return nb::cast(out[0]);
}
},
nb::arg(),
nb::arg(),
nb::sig(
"def depends(inputs: Union[array, Sequence[array]], dependencies: Union[array, Sequence[array]])"),
R"pbdoc(
Insert dependencies between arrays in the graph. The outputs are
identical to ``inputs`` but with dependencies on ``dependencies``.
Args:
inputs (array or Sequence[array]): The input array or arrays.
dependencies (array or Sequence[array]): The array or arrays
to insert dependencies on.
Returns:
array or Sequence[array]: The outputs which depend on dependencies.
)pbdoc");
}