diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 518c26c70..5aabaf388 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -585,7 +585,7 @@ class TestFast(mlx_tests.MLXTestCase): def test_custom_kernel_basic(self): if mx.metal.is_available(): source = """ - uint elem thread_position_in_grid.x; + uint elem = thread_position_in_grid.x; out1[elem] = a[elem]; """ custom_kernel = mx.fast.metal_kernel