Added Kronecker Product (#1728)

This commit is contained in:
Venkata Naga Aditya Datta Chivukula
2025-01-02 17:00:34 -07:00
committed by GitHub
parent 92ec632ad5
commit 491fa95b1f
4 changed files with 88 additions and 0 deletions

View File

@@ -1458,6 +1458,35 @@ void init_ops(nb::module_& m) {
Returns:
array: The range of values.
)pbdoc");
m.def(
"kron",
&kron,
nb::arg("a"),
nb::arg("b"),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the Kronecker product of two arrays `a` and `b`.
Args:
a (array): The first input array
b (array): The second input array
stream (Union[None, Stream, Device], optional): Optional stream or device for execution.
Default is `None`.
Returns:
array: The Kronecker product of `a` and `b`.
Examples:
>>> import mlx
>>> a = mlx.array([[1, 2], [3, 4]])
>>> b = mlx.array([[0, 5], [6, 7]])
>>> result = mlx.kron(a, b)
>>> print(result)
[[ 0 5 0 10]
[ 6 7 12 14]
[ 0 15 0 20]
[18 21 24 28]]
)pbdoc");
m.def(
"take",
[](const mx::array& a,