GPU Hadamard for large N (#1879)
@@ -2868,11 +2868,33 @@ class TestOps(mlx_tests.MLXTestCase):
         h28 = parse_h_string(h28_str)
 
+        x = mx.array(5)
+        y = mx.hadamard_transform(x)
+        self.assertEqual(y.item(), 5)
+
+        x = mx.array(5)
+        y = mx.hadamard_transform(x, scale=0.2)
+        self.assertEqual(y.item(), 1)
+
+        x = mx.random.normal((8, 8, 1))
+        y = mx.hadamard_transform(x)
+        self.assertTrue(mx.all(y == x).item())
+
+        # Too slow to compare to numpy so let's compare CPU to GPU
+        if mx.default_device() == mx.gpu:
+            rk = mx.random.key(42)
+            for k in range(14, 17):
+                for m in [1, 3, 5, 7]:
+                    x = mx.random.normal((4, m * 2**k), key=rk)
+                    y1 = mx.hadamard_transform(x, stream=mx.cpu)
+                    y2 = mx.hadamard_transform(x, stream=mx.gpu)
+                    self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6)
+
         np.random.seed(7)
-        tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15))
+        tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14))
         for dtype, m, k in tests:
             # skip large m=28 cases because they're very slow in NumPy
-            if (m > 1 and k > 8) or (dtype != np.float16 and k == 14):
+            if m > 1 and k > 8:
                 continue
             with self.subTest(dtype=dtype, m=m, k=k):
                 n = m * 2**k
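
For context, hadamard_transform multiplies the input along its last axis by a Hadamard matrix of size n = m * 2**k (the parsed h28 matrix covers the m = 28 factor). The snippet below is a hypothetical small-n reference check, not code from the repo: it assumes the natural (Sylvester) matrix ordering and a default scale of 1/sqrt(n), and the helper name sylvester_hadamard is made up for illustration.

import numpy as np
import mlx.core as mx

def sylvester_hadamard(n):
    # Dense Hadamard matrix via the Sylvester construction, valid for n = 2**k.
    h = np.ones((1, 1), dtype=np.float32)
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h

n = 16
x = np.random.normal(size=(4, n)).astype(np.float32)
# Reference: multiply each row by H_n and apply the assumed default scale 1/sqrt(n).
y_ref = (x @ sylvester_hadamard(n)) / np.sqrt(n)
y_mlx = np.array(mx.hadamard_transform(mx.array(x)))
print(np.abs(y_ref - y_mlx).max())  # expected to be on the order of 1e-6 for float32

A dense reference like this costs O(n^2) per row, which is why the NumPy comparison in the test is capped at k <= 13 while the larger sizes (k = 14 to 16) are checked by comparing the CPU and GPU streams instead.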