GPU Hadamard for large N (#1879)

This commit is contained in:
Angelos Katharopoulos
2025-02-18 13:43:09 -08:00
parent 9daa6b003f
commit 481349495b
5 changed files with 198 additions and 128 deletions

View File

@@ -2868,11 +2868,33 @@ class TestOps(mlx_tests.MLXTestCase):
h28 = parse_h_string(h28_str)
x = mx.array(5)
y = mx.hadamard_transform(x)
self.assertEqual(y.item(), 5)
x = mx.array(5)
y = mx.hadamard_transform(x, scale=0.2)
self.assertEqual(y.item(), 1)
x = mx.random.normal((8, 8, 1))
y = mx.hadamard_transform(x)
self.assertTrue(mx.all(y == x).item())
# Too slow to compare to numpy so let's compare CPU to GPU
if mx.default_device() == mx.gpu:
rk = mx.random.key(42)
for k in range(14, 17):
for m in [1, 3, 5, 7]:
x = mx.random.normal((4, m * 2**k), key=rk)
y1 = mx.hadamard_transform(x, stream=mx.cpu)
y2 = mx.hadamard_transform(x, stream=mx.gpu)
self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6)
np.random.seed(7)
tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15))
tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14))
for dtype, m, k in tests:
# skip large m=28 cases because they're very slow in NumPy
if (m > 1 and k > 8) or (dtype != np.float16 and k == 14):
if m > 1 and k > 8:
continue
with self.subTest(dtype=dtype, m=m, k=k):
n = m * 2**k