mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Add grid_sample example to metal_kernel docs (#1352)
				
					
				
			* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel` * add grid sample to docs * zero_outputs -> init_value * add missing header for linux
This commit is contained in:
		| @@ -200,10 +200,11 @@ void init_fast(nb::module_& parent_module) { | ||||
|       A jit-compiled custom Metal kernel defined from a source string. | ||||
|       )pbdoc") | ||||
|       .def( | ||||
|           nb::init<const std::string&, const std::string&, bool>(), | ||||
|           nb::init<const std::string&, const std::string&, bool, bool>(), | ||||
|           "name"_a, | ||||
|           "source"_a, | ||||
|           "ensure_row_contiguous"_a = true, | ||||
|           "atomic_outputs"_a = false, | ||||
|           R"pbdoc( | ||||
|       Initialize a metal_kernel. | ||||
|  | ||||
| @@ -215,6 +216,8 @@ void init_fast(nb::module_& parent_module) { | ||||
|             used when the kernel is called. | ||||
|         ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous | ||||
|             before the kernel runs. Default: ``True``. | ||||
|         atomic_outputs (bool): Whether to use atomic outputs in the function signature | ||||
|             e.g. ``device atomic<float>``. Default: ``False``. | ||||
|       Returns: | ||||
|         Callable ``metal_kernel``. | ||||
|  | ||||
| @@ -256,6 +259,7 @@ void init_fast(nb::module_& parent_module) { | ||||
|              std::tuple<int, int, int> grid, | ||||
|              std::tuple<int, int, int> threadgroup, | ||||
|              std::optional<std::map<std::string, nb::handle>> template_args_, | ||||
|              std::optional<float> init_value, | ||||
|              bool verbose, | ||||
|              StreamOrDevice s) { | ||||
|             std::map<std::string, array> inputs; | ||||
| @@ -289,6 +293,7 @@ void init_fast(nb::module_& parent_module) { | ||||
|                 grid, | ||||
|                 threadgroup, | ||||
|                 template_args, | ||||
|                 init_value, | ||||
|                 verbose, | ||||
|                 s); | ||||
|           }, | ||||
| @@ -299,10 +304,11 @@ void init_fast(nb::module_& parent_module) { | ||||
|           "grid"_a, | ||||
|           "threadgroup"_a, | ||||
|           "template"_a = nb::none(), | ||||
|           "init_value"_a = nb::none(), | ||||
|           "verbose"_a = false, | ||||
|           "stream"_a = nb::none(), | ||||
|           nb::sig( | ||||
|               "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), | ||||
|               "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), | ||||
|           R"pbdoc( | ||||
|             Run the kernel. | ||||
|  | ||||
| @@ -316,9 +322,11 @@ void init_fast(nb::module_& parent_module) { | ||||
|               grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. | ||||
|               threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. | ||||
|               template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments. | ||||
|                   These will be added as template arguments to the kernel definition. | ||||
|                   These will be added as template arguments to the kernel definition. Default: ``None``. | ||||
|               init_value (float, optional): Optional value to use to initialize all of the output arrays. | ||||
|                   By default, output arrays are uninitialized. Default: ``None``. | ||||
|               verbose (bool, optional): Whether to print the full generated source code of the kernel | ||||
|                   when it is run. | ||||
|                   when it is run. Default: ``False``. | ||||
|               stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. | ||||
|  | ||||
|             Returns: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron