mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix bug with multiple attributes (#1348)
Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
parent
98b6ce3460
commit
da8deb2b62
@ -1052,6 +1052,7 @@ void write_signature(
|
|||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
// Add metal attributes e.g. `threadgroup_index_in_grid`
|
// Add metal attributes e.g. `threadgroup_index_in_grid`
|
||||||
|
index = 0;
|
||||||
for (const auto& [attr, dtype] : attrs) {
|
for (const auto& [attr, dtype] : attrs) {
|
||||||
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
|
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
|
||||||
if (index < attrs.size() - 1) {
|
if (index < attrs.size() - 1) {
|
||||||
@ -1059,6 +1060,7 @@ void write_signature(
|
|||||||
} else {
|
} else {
|
||||||
kernel_source << ") {" << std::endl;
|
kernel_source << ") {" << std::endl;
|
||||||
}
|
}
|
||||||
|
index++;
|
||||||
}
|
}
|
||||||
kernel_source << source << std::endl;
|
kernel_source << source << std::endl;
|
||||||
kernel_source << "}" << std::endl;
|
kernel_source << "}" << std::endl;
|
||||||
|
@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||||
T tmp = inp[loc];
|
T tmp = inp[loc];
|
||||||
out[elem] = metal::exp(tmp);
|
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
|
||||||
"""
|
"""
|
||||||
source_contig = """
|
source_contig = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
T tmp = inp[elem];
|
T tmp = inp[elem];
|
||||||
out[elem] = metal::exp(tmp);
|
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# non contiguous
|
# non contiguous
|
||||||
@ -644,7 +644,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
output_dtypes={"out": a.dtype},
|
output_dtypes={"out": a.dtype},
|
||||||
stream=mx.gpu,
|
stream=mx.gpu,
|
||||||
)
|
)
|
||||||
self.assertTrue(mx.allclose(mx.exp(a), outputs["out"]))
|
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user