More fixes for arrays with large sizes (#1405)

* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
This commit is contained in:
Awni Hannun
2024-09-17 12:46:31 -07:00
committed by GitHub
parent c6739ba7f3
commit 4f46e9c997
26 changed files with 325 additions and 611 deletions

View File

@@ -243,7 +243,7 @@ void init_fast(nb::module_& parent_module) {
template_args.emplace_back(name, dtype);
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
"[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
@@ -271,25 +271,24 @@ void init_fast(nb::module_& parent_module) {
nb::sig(
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Run the kernel.
Args:
inputs (List[array]): The inputs passed to the Metal kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Args:
inputs (List[array]): The inputs passed to the Metal kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
List[array]: The list of output arrays.
)pbdoc");
Returns:
List[array]: The list of output arrays.)pbdoc");
},
"name"_a,
"input_names"_a,
@@ -306,16 +305,16 @@ void init_fast(nb::module_& parent_module) {
input_names (List[str]): The parameter names of the inputs in the
function signature.
output_names (List[str]): The parameter names of the outputs in the
function signature.
function signature.
source (str): Source code. This is the body of a function in Metal,
the function signature will be automatically generated.
the function signature will be automatically generated.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of
the main function body.
Useful for helper functions or includes that should live outside of
the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
e.g. ``device atomic<float>``. Default: ``False``.
Returns:
Callable ``metal_kernel``.