mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix bug with multiple attributes (#1348)
Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
		| @@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|             uint elem = thread_position_in_grid.x; | ||||
|             uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); | ||||
|             T tmp = inp[loc]; | ||||
|             out[elem] = metal::exp(tmp); | ||||
|             out[elem] = metal::exp(tmp) * threads_per_simdgroup; | ||||
|         """ | ||||
|         source_contig = """ | ||||
|             uint elem = thread_position_in_grid.x; | ||||
|             T tmp = inp[elem]; | ||||
|             out[elem] = metal::exp(tmp); | ||||
|             out[elem] = metal::exp(tmp) * threads_per_simdgroup; | ||||
|         """ | ||||
|  | ||||
|         # non contiguous | ||||
| @@ -644,7 +644,7 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|                 output_dtypes={"out": a.dtype}, | ||||
|                 stream=mx.gpu, | ||||
|             ) | ||||
|             self.assertTrue(mx.allclose(mx.exp(a), outputs["out"])) | ||||
|             self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"])) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron