Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-11 06:24:35 +08:00
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard
* add scale
* review comments

Co-authored-by: Alex Barron <abarron22@apple.com>
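For context, the first hunk below adds a consistency check on the bfloat16 path of mx.hadamard_transform when an explicit scale is passed. A minimal standalone sketch of that kind of check follows; the transform size, the 1/sqrt(n) scale, and the tolerance are illustrative assumptions, not values taken from this commit.

# Sketch only: compare mx.hadamard_transform in bfloat16 against a float32 run.
# Size, scale, and tolerance are illustrative assumptions, not from the commit.
import math

import numpy as np
import mlx.core as mx

k = 8
n = 2**k  # power-of-two transform length
scale = 1.0 / math.sqrt(n)  # assumed orthonormal scaling

x = mx.random.normal((4, n))
y_f32 = mx.hadamard_transform(x, scale=scale)
y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16), scale=scale)

# bfloat16 keeps far fewer mantissa bits, so use a loose tolerance
# (in the spirit of the atol * 2 comparison in the test below).
np.testing.assert_allclose(
    np.array(y_bf16.astype(mx.float32)),
    np.array(y_f32),
    atol=5e-2 * k * 2,
)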
@@ -2496,6 +2496,13 @@ class TestOps(mlx_tests.MLXTestCase):
                 atol = 2e-4 if dtype == np.float32 else 5e-2 * k
                 np.testing.assert_allclose(y, y_np, atol=atol)
 
+                # bfloat16 emulation on M1 means 2**14 doesn't fit in threadgroup memory
+                if dtype == np.float16 and k < 14:
+                    y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16), scale=scale)
+                    np.testing.assert_allclose(
+                        y_bf16.astype(mx.float16), y, atol=atol * 2
+                    )
+
     def test_hadamard_grad_vmap(self):
         np.random.seed(4)
 
@@ -2509,7 +2516,7 @@ class TestOps(mlx_tests.MLXTestCase):
         c = mx.array(c).astype(mx.float32)
 
         def hadamard_transform(x):
-            return h @ x
+            return h @ x / mx.sqrt(x.shape[-1])
 
         out = mx.vjp(hadamard_transform, [x], [c])
         out_t = mx.vjp(mx.hadamard_transform, [x], [c])
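The second hunk divides the hand-written reference by 1/sqrt(n) so it matches the scaling that mx.hadamard_transform applies by default before its VJP is compared. A rough standalone sketch of such a VJP check, using an explicitly built Hadamard matrix, is below; the matrix construction, shapes, and tolerance are assumptions for illustration.

# Sketch only: compare the VJP of mx.hadamard_transform with an explicit-matrix reference.
# The Kronecker construction, shapes, and tolerance are illustrative assumptions.
import math

import numpy as np
import mlx.core as mx

k = 5
n = 2**k

# Build the n x n Hadamard matrix by repeated Kronecker products of [[1, 1], [1, -1]].
h2 = np.array([[1.0, 1.0], [1.0, -1.0]], dtype=np.float32)
h_np = np.array([[1.0]], dtype=np.float32)
for _ in range(k):
    h_np = np.kron(h_np, h2)
h = mx.array(h_np)

x = mx.random.normal((n,))
c = mx.random.normal((n,))  # cotangent fed into the VJP

def hadamard_transform_ref(z):
    # normalized reference, matching the 1 / sqrt(n) scaling in the hunk above
    return h @ z / math.sqrt(n)

_, vjp_ref = mx.vjp(hadamard_transform_ref, [x], [c])
_, vjp_out = mx.vjp(mx.hadamard_transform, [x], [c])

np.testing.assert_allclose(np.array(vjp_out[0]), np.array(vjp_ref[0]), atol=1e-4)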