diff --git a/mlx/fast.cpp b/mlx/fast.cpp index d06b4ce0e..3a7c80f8b 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1052,6 +1052,7 @@ void write_signature( index++; } // Add metal attributes e.g. `threadgroup_index_in_grid` + index = 0; for (const auto& [attr, dtype] : attrs) { kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]"; if (index < attrs.size() - 1) { @@ -1059,6 +1060,7 @@ void write_signature( } else { kernel_source << ") {" << std::endl; } + index++; } kernel_source << source << std::endl; kernel_source << "}" << std::endl; diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 121d5d482..369ac2086 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase): uint elem = thread_position_in_grid.x; uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); T tmp = inp[loc]; - out[elem] = metal::exp(tmp); + out[elem] = metal::exp(tmp) * threads_per_simdgroup; """ source_contig = """ uint elem = thread_position_in_grid.x; T tmp = inp[elem]; - out[elem] = metal::exp(tmp); + out[elem] = metal::exp(tmp) * threads_per_simdgroup; """ # non contiguous @@ -644,7 +644,7 @@ class TestFast(mlx_tests.MLXTestCase): output_dtypes={"out": a.dtype}, stream=mx.gpu, ) - self.assertTrue(mx.allclose(mx.exp(a), outputs["out"])) + self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"])) if __name__ == "__main__":