mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
fix bug with multiple attributes (#1348)
Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
uint elem = thread_position_in_grid.x;
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
out[elem] = metal::exp(tmp);
|
||||
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
source_contig = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
|
||||
"""
|
||||
|
||||
# non contiguous
|
||||
@@ -644,7 +644,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
output_dtypes={"out": a.dtype},
|
||||
stream=mx.gpu,
|
||||
)
|
||||
self.assertTrue(mx.allclose(mx.exp(a), outputs["out"]))
|
||||
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user