From 6f608857dbdebc23b84538e813130125c0459541 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 20 Aug 2025 17:19:36 -0700
Subject: [PATCH] Address more comments

---
 docs/src/python/cuda.rst | 9 +++++++++
 docs/src/python/fast.rst | 1 +
 python/src/fast.cpp      | 2 +-
 3 files changed, 11 insertions(+), 1 deletion(-)
 create mode 100644 docs/src/python/cuda.rst

diff --git a/docs/src/python/cuda.rst b/docs/src/python/cuda.rst
new file mode 100644
index 000000000..932d36b5e
--- /dev/null
+++ b/docs/src/python/cuda.rst
@@ -0,0 +1,9 @@
+CUDA
+=====
+
+.. currentmodule:: mlx.core.cuda
+
+.. autosummary::
+  :toctree: _autosummary
+
+  is_available
diff --git a/docs/src/python/fast.rst b/docs/src/python/fast.rst
index f78f40563..b250dcb18 100644
--- a/docs/src/python/fast.rst
+++ b/docs/src/python/fast.rst
@@ -13,3 +13,4 @@ Fast
   rope
   scaled_dot_product_attention
   metal_kernel
+  cuda_kernel
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 02e924a94..12d6de358 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -438,7 +438,7 @@ void init_fast(nb::module_& parent_module) {
               output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
               output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
               grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
-                For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
+                For compatibility with :func:`metal_kernel` the grid is in threads and not in threadgroups.
               threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
               template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
                 These will be added as template arguments to the kernel definition. Default: ``None``.