mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Add optional headers to `mx.fast.metal_kernel` (#1358)
				
					
				
			This commit is contained in:
		| @@ -200,9 +200,15 @@ void init_fast(nb::module_& parent_module) { | ||||
|       A jit-compiled custom Metal kernel defined from a source string. | ||||
|       )pbdoc") | ||||
|       .def( | ||||
|           nb::init<const std::string&, const std::string&, bool, bool>(), | ||||
|           nb::init< | ||||
|               const std::string&, | ||||
|               const std::string&, | ||||
|               const std::string&, | ||||
|               bool, | ||||
|               bool>(), | ||||
|           "name"_a, | ||||
|           "source"_a, | ||||
|           "header"_a = "", | ||||
|           "ensure_row_contiguous"_a = true, | ||||
|           "atomic_outputs"_a = false, | ||||
|           R"pbdoc( | ||||
| @@ -214,6 +220,8 @@ void init_fast(nb::module_& parent_module) { | ||||
|             the function signature will be generated for you. The names of the inputs/outputs | ||||
|             are determined by the ``inputs`` and ``output_shapes``/``output_dtypes`` | ||||
|             used when the kernel is called. | ||||
|         header (str): Header source code to include before the main function. | ||||
|             Useful for helper functions or includes that should live outside of the | ||||
|             main function body. Default: ``""``. | ||||
|         ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous | ||||
|             before the kernel runs. Default: ``True``. | ||||
|         atomic_outputs (bool): Whether to use atomic outputs in the function signature | ||||
|   | ||||
| @@ -551,7 +551,7 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") | ||||
|     def test_custom_kernel_basic(self): | ||||
|         mx.random.seed(7) | ||||
|         a = mx.random.normal(shape=(3, 6)) | ||||
|         a = mx.random.normal(shape=(2, 2)) | ||||
|         kernel = mx.fast.metal_kernel( | ||||
|             name="basic", | ||||
|             source=""" | ||||
| @@ -567,7 +567,7 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|             output_dtypes={"out1": mx.float32}, | ||||
|             stream=mx.gpu, | ||||
|         ) | ||||
|         mx.allclose(out["out1"], a[:2, :2]) | ||||
|         self.assertTrue(mx.allclose(out["out1"], a)) | ||||
|  | ||||
|     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") | ||||
|     def test_custom_kernel_args(self): | ||||
| @@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|             uint elem = thread_position_in_grid.x; | ||||
|             uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); | ||||
|             T tmp = inp[loc]; | ||||
|             out[elem] = metal::exp(tmp) * threads_per_simdgroup; | ||||
|             out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; | ||||
|         """ | ||||
|         source_contig = """ | ||||
|             uint elem = thread_position_in_grid.x; | ||||
|             T tmp = inp[elem]; | ||||
|             out[elem] = metal::exp(tmp) * threads_per_simdgroup; | ||||
|             out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; | ||||
|         """ | ||||
|  | ||||
|         # non contiguous | ||||
| @@ -646,6 +646,33 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|             ) | ||||
|             self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"])) | ||||
|  | ||||
|     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") | ||||
|     def test_custom_kernel_helper(self): | ||||
|         mx.random.seed(7) | ||||
|         a = mx.random.normal(shape=(2, 2)) | ||||
|         kernel = mx.fast.metal_kernel( | ||||
|             name="helper", | ||||
|             header=""" | ||||
|             template <typename T> | ||||
|             T do_exp(T x) { | ||||
|                 return metal::precise::exp(x); | ||||
|             } | ||||
|             """, | ||||
|             source=""" | ||||
|                 uint elem = thread_position_in_grid.x; | ||||
|                 out1[elem] = do_exp(a[elem]); | ||||
|             """, | ||||
|         ) | ||||
|         out = kernel( | ||||
|             inputs={"a": a}, | ||||
|             grid=(4, 1, 1), | ||||
|             threadgroup=(2, 1, 1), | ||||
|             output_shapes={"out1": (2, 2)}, | ||||
|             output_dtypes={"out1": mx.float32}, | ||||
|             stream=mx.gpu, | ||||
|         ) | ||||
|         self.assertTrue(mx.allclose(out["out1"], mx.exp(a))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron