Fix bfloat16 Hadamard (#1283)

* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-07-23 14:54:43 -07:00
committed by GitHub
parent e2aa6ec8ae
commit c34a5ae7f7
5 changed files with 20 additions and 10 deletions

View File

@@ -2496,6 +2496,13 @@ class TestOps(mlx_tests.MLXTestCase):
atol = 2e-4 if dtype == np.float32 else 5e-2 * k
np.testing.assert_allclose(y, y_np, atol=atol)
# bfloat16 emulation on M1 means 2**14 doesn't fit in threadgroup memory
if dtype == np.float16 and k < 14:
y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16), scale=scale)
np.testing.assert_allclose(
y_bf16.astype(mx.float16), y, atol=atol * 2
)
def test_hadamard_grad_vmap(self):
np.random.seed(4)
@@ -2509,7 +2516,7 @@ class TestOps(mlx_tests.MLXTestCase):
c = mx.array(c).astype(mx.float32)
def hadamard_transform(x):
return h @ x
return h @ x / mx.sqrt(x.shape[-1])
out = mx.vjp(hadamard_transform, [x], [c])
out_t = mx.vjp(mx.hadamard_transform, [x], [c])