mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Added Kronecker Product (#1728)
This commit is contained in:

committed by
GitHub

parent
92ec632ad5
commit
491fa95b1f
@@ -1458,6 +1458,35 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The range of values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"kron",
|
||||
&kron,
|
||||
nb::arg("a"),
|
||||
nb::arg("b"),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the Kronecker product of two arrays `a` and `b`.
|
||||
Args:
|
||||
a (array): The first input array
|
||||
b (array): The second input array
|
||||
stream (Union[None, Stream, Device], optional): Optional stream or device for execution.
|
||||
Default is `None`.
|
||||
Returns:
|
||||
array: The Kronecker product of `a` and `b`.
|
||||
Examples:
|
||||
>>> import mlx
|
||||
>>> a = mlx.array([[1, 2], [3, 4]])
|
||||
>>> b = mlx.array([[0, 5], [6, 7]])
|
||||
>>> result = mlx.kron(a, b)
|
||||
>>> print(result)
|
||||
[[ 0 5 0 10]
|
||||
[ 6 7 12 14]
|
||||
[ 0 15 0 20]
|
||||
[18 21 24 28]]
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"take",
|
||||
[](const mx::array& a,
|
||||
|
Reference in New Issue
Block a user