mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2
* working for m*2^k
* add scale and check contiguity
* add size check
* clean up
* fix test
* add grads + vmap
* gpu only
* skip on linux
* test typo
* add cpu impl
* remove gpu only tests
* fix linux build + add is_equivalent
This commit is contained in:
@@ -2425,6 +2425,104 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a_out = out.view(mx.int32)
|
||||
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
|
||||
|
||||
def _hadamard(self, N):
|
||||
# Matches scipy.linalg.hadamard
|
||||
H = np.array([[1]], dtype=np.int64)
|
||||
for i in range(0, np.log2(N).astype(np.int64)):
|
||||
H = np.vstack((np.hstack((H, H)), np.hstack((H, -H))))
|
||||
return H
|
||||
|
||||
def test_hadamard(self):
    """Compare mx.hadamard_transform against a dense NumPy matrix product.

    Sizes are n = m * 2**k with m in {1, 28}, over float32, float16 and
    int32 inputs. For m == 28 the reference matrix is the Kronecker product
    of the 28 x 28 matrix below with the 2**k Hadamard matrix.
    """
    # 28 x 28 sign matrix, one row per line: "+" means +1, "-" means -1.
    h28_str = """
    +------++----++-+--+-+--++--
    -+-----+++-----+-+--+-+--++-
    --+-----+++---+-+-+----+--++
    ---+-----+++---+-+-+-+--+--+
    ----+-----+++---+-+-+++--+--
    -----+-----++++--+-+--++--+-
    ------++----++-+--+-+--++--+
    --++++-+-------++--+++-+--+-
    ---++++-+-----+-++--+-+-+--+
    +---+++--+----++-++--+-+-+--
    ++---++---+----++-++--+-+-+-
    +++---+----+----++-++--+-+-+
    ++++--------+-+--++-++--+-+-
    -++++--------+++--++--+--+-+
    -+-++-++--++--+--------++++-
    +-+-++--+--++--+--------++++
    -+-+-++--+--++--+----+---+++
    +-+-+-++--+--+---+---++---++
    ++-+-+-++--+------+--+++---+
    -++-+-+-++--+------+-++++---
    +-++-+---++--+------+-++++--
    -++--++-+-++-+++----++------
    +-++--++-+-++-+++-----+-----
    ++-++---+-+-++-+++-----+----
    -++-++-+-+-+-+--+++-----+---
    --++-++++-+-+----+++-----+--
    +--++-+-++-+-+----+++-----+-
    ++--++-+-++-+-+----++------+
    """

    def parse_h_string(h_str):
        # Each whitespace-separated token is one matrix row of sign characters.
        rows = []
        for line in h_str.split():
            rows.append([1 if sign == "+" else -1 for sign in line])
        return np.array(rows)

    h28 = parse_h_string(h28_str)

    np.random.seed(7)
    batch = 4
    scale = 0.34
    dtypes = [np.float32, np.float16, np.int32]
    for dtype, m, k in product(dtypes, [1, 28], range(1, 15)):
        # skip large m=28 cases because they're very slow in NumPy
        if m > 1 and k > 8:
            continue
        # k == 14 is only exercised for float16
        if k == 14 and dtype != np.float16:
            continue

        with self.subTest(dtype=dtype, m=m, k=k):
            n = m * 2**k
            samples = np.random.normal(size=(batch, n)).astype(dtype)
            # contiguity check: hand the op a row-strided slice
            x = mx.array(samples)[::2]
            y = mx.hadamard_transform(x, scale=scale)
            mx.eval(y)

            # Dense reference matrix for this size.
            h = self._hadamard(2**k)
            if m != 1:
                h = np.kron(h28, h)
            expected = np.einsum("ij,bj->bi", h, x) * scale

            if dtype == np.float32:
                atol = 2e-4
            else:
                # Looser tolerance for the other dtypes, growing with k.
                atol = 5e-2 * k
            np.testing.assert_allclose(y, expected, atol=atol)
|
||||
|
||||
def test_hadamard_grad_vmap(self):
    """Check vjp and vmap of mx.hadamard_transform against a dense matmul."""
    # NOTE(review): loop nesting was reconstructed from whitespace-stripped
    # source; the axis loop is assumed to run once per k -- confirm.
    np.random.seed(4)

    for k in range(2, 8):
        n = 2**k
        # Draw the input first and the cotangent second to keep the random
        # stream in its original order.
        x_np = np.random.normal(size=(n,))
        c_np = np.random.normal(size=(n,))
        x = mx.array(x_np).astype(mx.float32)
        h = mx.array(self._hadamard(n)).astype(mx.float32)
        c = mx.array(c_np).astype(mx.float32)

        def dense_transform(v):
            # Reference implementation: explicit matrix product.
            return h @ v

        dense_vjp = mx.vjp(dense_transform, [x], [c])
        fast_vjp = mx.vjp(mx.hadamard_transform, [x], [c])
        np.testing.assert_allclose(dense_vjp, fast_vjp, atol=1e-4)

        for axis in (0, 1, 2):
            # Inner vmap over axis 0, outer vmap over the axis under test.
            batched_dense = mx.vmap(mx.vmap(dense_transform, 0, 0), axis, axis)
            batched_fast = mx.vmap(
                mx.vmap(mx.hadamard_transform, 0, 0), axis, axis
            )

            cube = mx.array(np.random.normal(size=(n, n, n)))
            dense_out = batched_dense(cube)
            fast_out = batched_fast(cube)
            np.testing.assert_allclose(dense_out, fast_out, atol=1e-4)
|
||||
|
||||
|
||||
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
Reference in New Issue
Block a user