mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous * style * nits in docs * a bunch more stuff * update jit * update jit * use constant for shapes and strides and remove elem_to_loc overload * use kernel instantiation * docs nits * update binary and ternary * comments
This commit is contained in:
@@ -243,7 +243,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
template_args.emplace_back(name, dtype);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
|
||||
"[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -271,25 +271,24 @@ void init_fast(nb::module_& parent_module) {
|
||||
nb::sig(
|
||||
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
|
||||
R"pbdoc(
|
||||
Run the kernel.
|
||||
Run the kernel.
|
||||
|
||||
Args:
|
||||
inputs (List[array]): The inputs passed to the Metal kernel.
|
||||
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
|
||||
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
|
||||
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
|
||||
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
|
||||
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
|
||||
These will be added as template arguments to the kernel definition. Default: ``None``.
|
||||
init_value (float, optional): Optional value to use to initialize all of the output arrays.
|
||||
By default, output arrays are uninitialized. Default: ``None``.
|
||||
verbose (bool, optional): Whether to print the full generated source code of the kernel
|
||||
when it is run. Default: ``False``.
|
||||
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
|
||||
Args:
|
||||
inputs (List[array]): The inputs passed to the Metal kernel.
|
||||
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
|
||||
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
|
||||
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
|
||||
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
|
||||
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
|
||||
These will be added as template arguments to the kernel definition. Default: ``None``.
|
||||
init_value (float, optional): Optional value to use to initialize all of the output arrays.
|
||||
By default, output arrays are uninitialized. Default: ``None``.
|
||||
verbose (bool, optional): Whether to print the full generated source code of the kernel
|
||||
when it is run. Default: ``False``.
|
||||
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
List[array]: The list of output arrays.
|
||||
)pbdoc");
|
||||
Returns:
|
||||
List[array]: The list of output arrays.)pbdoc");
|
||||
},
|
||||
"name"_a,
|
||||
"input_names"_a,
|
||||
@@ -306,16 +305,16 @@ void init_fast(nb::module_& parent_module) {
|
||||
input_names (List[str]): The parameter names of the inputs in the
|
||||
function signature.
|
||||
output_names (List[str]): The parameter names of the outputs in the
|
||||
function signature.
|
||||
function signature.
|
||||
source (str): Source code. This is the body of a function in Metal,
|
||||
the function signature will be automatically generated.
|
||||
the function signature will be automatically generated.
|
||||
header (str): Header source code to include before the main function.
|
||||
Useful for helper functions or includes that should live outside of
|
||||
the main function body.
|
||||
Useful for helper functions or includes that should live outside of
|
||||
the main function body.
|
||||
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
|
||||
before the kernel runs. Default: ``True``.
|
||||
before the kernel runs. Default: ``True``.
|
||||
atomic_outputs (bool): Whether to use atomic outputs in the function signature
|
||||
e.g. ``device atomic<float>``. Default: ``False``.
|
||||
e.g. ``device atomic<float>``. Default: ``False``.
|
||||
|
||||
Returns:
|
||||
Callable ``metal_kernel``.
|
||||
|
Reference in New Issue
Block a user