mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 07:18:15 +08:00
27
python/mlx/__main__.py
Normal file
27
python/mlx/__main__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import argparse
|
||||
|
||||
|
||||
def main() -> None:
|
||||
from mlx.core import __version__
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
action="version",
|
||||
version=__version__,
|
||||
help="Print the version number.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cmake-dir",
|
||||
action="store_true",
|
||||
help="Print the path to the MLX CMake module directory.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.cmake_dir:
|
||||
from pathlib import Path
|
||||
|
||||
print(Path(__file__).parent)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1468,24 +1468,26 @@ void init_ops(nb::module_& m) {
|
||||
nb::sig(
|
||||
"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the Kronecker product of two arrays `a` and `b`.
|
||||
Compute the Kronecker product of two arrays ``a`` and ``b``.
|
||||
|
||||
Args:
|
||||
a (array): The first input array
|
||||
b (array): The second input array
|
||||
stream (Union[None, Stream, Device], optional): Optional stream or device for execution.
|
||||
Default is `None`.
|
||||
a (array): The first input array.
|
||||
b (array): The second input array.
|
||||
stream (Union[None, Stream, Device], optional): Optional stream or
|
||||
device for execution. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The Kronecker product of `a` and `b`.
|
||||
array: The Kronecker product of ``a`` and ``b``.
|
||||
|
||||
Examples:
|
||||
>>> import mlx
|
||||
>>> a = mlx.array([[1, 2], [3, 4]])
|
||||
>>> b = mlx.array([[0, 5], [6, 7]])
|
||||
>>> result = mlx.kron(a, b)
|
||||
>>> a = mx.array([[1, 2], [3, 4]])
|
||||
>>> b = mx.array([[0, 5], [6, 7]])
|
||||
>>> result = mx.kron(a, b)
|
||||
>>> print(result)
|
||||
[[ 0 5 0 10]
|
||||
[ 6 7 12 14]
|
||||
[ 0 15 0 20]
|
||||
[18 21 24 28]]
|
||||
array([[0, 5, 0, 10],
|
||||
[6, 7, 12, 14],
|
||||
[0, 15, 0, 20],
|
||||
[18, 21, 24, 28]], dtype=int32)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"take",
|
||||
|
||||
Reference in New Issue
Block a user