Add mx.meshgrid (#961)

This commit is contained in:
Abe Leininger
2024-04-09 14:43:08 -04:00
committed by GitHub
parent ae812350f9
commit a1a31eed27
6 changed files with 161 additions and 1 deletions

View File

@@ -2568,6 +2568,35 @@ void init_ops(nb::module_& m) {
Returns:
array: The resulting stacked array.
)pbdoc");
m.def(
"meshgrid",
[](nb::args arrays_,
bool sparse,
std::string indexing,
StreamOrDevice s) {
std::vector<array> arrays = nb::cast<std::vector<array>>(arrays_);
return meshgrid(arrays, sparse, indexing, s);
},
"arrays"_a,
"sparse"_a = false,
"indexing"_a = "xy",
"stream"_a = nb::none(),
nb::sig(
"def meshgrid(*arrays: array, sparse: Optional[bool] = false, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate multidimensional coordinate grids from 1-D coordinate arrays
Args:
arrays (array): Input arrays.
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
array has a single non-zero element. If ``False``, a dense grid is returned.
Defaults to ``False``.
indexing (str, optional): Cartesian ('xy') or matrix ('ij') indexing of the output arrays.
Defaults to ``'xy'``.
Returns:
list(array): The output arrays.
)pbdoc");
m.def(
"repeat",
[](const array& array,