MLX in C++ example (#1736)

* MLX in C++ example

* nits

* fix docs
This commit is contained in:
Awni Hannun
2025-01-02 19:09:04 -08:00
committed by GitHub
parent 8544b42007
commit c9d30aa6ac
13 changed files with 242 additions and 31 deletions

View File

@@ -1468,24 +1468,26 @@ void init_ops(nb::module_& m) {
nb::sig(
"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the Kronecker product of two arrays `a` and `b`.
Compute the Kronecker product of two arrays ``a`` and ``b``.
Args:
a (array): The first input array
b (array): The second input array
stream (Union[None, Stream, Device], optional): Optional stream or device for execution.
Default is `None`.
a (array): The first input array.
b (array): The second input array.
stream (Union[None, Stream, Device], optional): Optional stream or
device for execution. Default: ``None``.
Returns:
array: The Kronecker product of `a` and `b`.
array: The Kronecker product of ``a`` and ``b``.
Examples:
>>> import mlx
>>> a = mlx.array([[1, 2], [3, 4]])
>>> b = mlx.array([[0, 5], [6, 7]])
>>> result = mlx.kron(a, b)
>>> a = mx.array([[1, 2], [3, 4]])
>>> b = mx.array([[0, 5], [6, 7]])
>>> result = mx.kron(a, b)
>>> print(result)
[[ 0 5 0 10]
[ 6 7 12 14]
[ 0 15 0 20]
[18 21 24 28]]
array([[0, 5, 0, 10],
[6, 7, 12, 14],
[0, 15, 0, 20],
[18, 21, 24, 28]], dtype=int32)
)pbdoc");
m.def(
"take",