mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard * add scale * review comments --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
		| @@ -2496,6 +2496,13 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                 atol = 2e-4 if dtype == np.float32 else 5e-2 * k | ||||
|                 np.testing.assert_allclose(y, y_np, atol=atol) | ||||
|  | ||||
|                 # bfloat16 emulation on M1 means 2**14 doesn't fit in threadgroup memory | ||||
|                 if dtype == np.float16 and k < 14: | ||||
|                     y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16), scale=scale) | ||||
|                     np.testing.assert_allclose( | ||||
|                         y_bf16.astype(mx.float16), y, atol=atol * 2 | ||||
|                     ) | ||||
|  | ||||
|     def test_hadamard_grad_vmap(self): | ||||
|         np.random.seed(4) | ||||
|  | ||||
| @@ -2509,7 +2516,7 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             c = mx.array(c).astype(mx.float32) | ||||
|  | ||||
|             def hadamard_transform(x): | ||||
|                 return h @ x | ||||
|                 return h @ x / mx.sqrt(x.shape[-1]) | ||||
|  | ||||
|             out = mx.vjp(hadamard_transform, [x], [c]) | ||||
|             out_t = mx.vjp(mx.hadamard_transform, [x], [c]) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron