Address more comments

This commit is contained in:
Angelos Katharopoulos 2025-08-20 17:19:36 -07:00
parent d6b204b528
commit 6f608857db
3 changed files with 11 additions and 1 deletions

9
docs/src/python/cuda.rst Normal file
View File

@ -0,0 +1,9 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@ -13,3 +13,4 @@ Fast
rope
scaled_dot_product_attention
metal_kernel
cuda_kernel

View File

@ -438,7 +438,7 @@ void init_fast(nb::module_& parent_module) {
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
For compatibility with :func:`metal_kernel` the grid is in threads and not in threadgroups.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.