mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Fix unintuitive metal kernel caching (#2242)
* Fix unintuitive metal kernel caching * alternative solution
This commit is contained in:
		| @@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|         )[0] | ||||
|         self.assertEqual(out.item(), 2) | ||||
|  | ||||
|     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") | ||||
|     def test_custom_kernel_caching(self): | ||||
|         def call_kernel(a: mx.array, source): | ||||
|             kernel = mx.fast.metal_kernel( | ||||
|                 name="my_kernel", | ||||
|                 input_names=["inp"], | ||||
|                 output_names=["out"], | ||||
|                 source=source, | ||||
|             ) | ||||
|             return kernel( | ||||
|                 inputs=[a], | ||||
|                 grid=(a.size, 1, 1), | ||||
|                 threadgroup=(a.size, 1, 1), | ||||
|                 output_shapes=[a.shape], | ||||
|                 output_dtypes=[a.dtype], | ||||
|                 stream=mx.gpu, | ||||
|             )[0] | ||||
|  | ||||
|         a = mx.random.normal(shape=(32,)) | ||||
|  | ||||
|         source = """ | ||||
|             uint elem = thread_position_in_grid.x; | ||||
|             out[elem] = 0.0; | ||||
|         """ | ||||
|  | ||||
|         out = call_kernel(a, source) | ||||
|         self.assertTrue(mx.array_equal(out, mx.zeros_like(out))) | ||||
|  | ||||
|         source = """ | ||||
|             uint elem = thread_position_in_grid.x; | ||||
|             out[elem] = 1.0; | ||||
|         """ | ||||
|         out = call_kernel(a, source) | ||||
|         self.assertTrue(mx.array_equal(out, mx.ones_like(out))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun