fix bug with multiple attributes (#1348)

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron 2024-08-23 10:06:15 -07:00 committed by GitHub
parent 98b6ce3460
commit da8deb2b62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 3 deletions

View File

@ -1052,6 +1052,7 @@ void write_signature(
index++; index++;
} }
// Add metal attributes e.g. `threadgroup_index_in_grid` // Add metal attributes e.g. `threadgroup_index_in_grid`
index = 0;
for (const auto& [attr, dtype] : attrs) { for (const auto& [attr, dtype] : attrs) {
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]"; kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
if (index < attrs.size() - 1) { if (index < attrs.size() - 1) {
@ -1059,6 +1060,7 @@ void write_signature(
} else { } else {
kernel_source << ") {" << std::endl; kernel_source << ") {" << std::endl;
} }
index++;
} }
kernel_source << source << std::endl; kernel_source << source << std::endl;
kernel_source << "}" << std::endl; kernel_source << "}" << std::endl;

View File

@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase):
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc]; T tmp = inp[loc];
out[elem] = metal::exp(tmp); out[elem] = metal::exp(tmp) * threads_per_simdgroup;
""" """
source_contig = """ source_contig = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
T tmp = inp[elem]; T tmp = inp[elem];
out[elem] = metal::exp(tmp); out[elem] = metal::exp(tmp) * threads_per_simdgroup;
""" """
# non contiguous # non contiguous
@ -644,7 +644,7 @@ class TestFast(mlx_tests.MLXTestCase):
output_dtypes={"out": a.dtype}, output_dtypes={"out": a.dtype},
stream=mx.gpu, stream=mx.gpu,
) )
self.assertTrue(mx.allclose(mx.exp(a), outputs["out"])) self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
if __name__ == "__main__": if __name__ == "__main__":