Fast Hadamard Transform (#1249)

* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
This commit is contained in:
Alex Barron
2024-07-09 20:39:01 -07:00
committed by GitHub
parent 03cf033f82
commit a3c287354f
22 changed files with 878 additions and 11 deletions

View File

@@ -2425,6 +2425,104 @@ class TestOps(mlx_tests.MLXTestCase):
a_out = out.view(mx.int32)
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
def _hadamard(self, N):
    """Build a +1/-1 Hadamard matrix by repeated block doubling.

    Starts from the 1x1 matrix ``[[1]]`` and applies the step
    ``H -> [[H, H], [H, -H]]`` once per unit of ``log2(N)`` (truncated to an
    integer), so the result is ``N x N`` when ``N`` is a power of two.
    Intended to match ``scipy.linalg.hadamard``.

    Args:
        N: Requested matrix order.

    Returns:
        A 2-D ``np.int64`` array with entries +1 and -1.
    """
    doublings = np.log2(N).astype(np.int64)
    mat = np.array([[1]], dtype=np.int64)
    for _ in range(doublings):
        # Sylvester construction: each step doubles both dimensions.
        mat = np.block([[mat, mat], [mat, -mat]])
    return mat
def test_hadamard(self):
    # Compare mx.hadamard_transform against a dense NumPy reference
    # (H @ x * scale) for sizes n = m * 2**k with m in {1, 28}, over
    # float32, float16 and int32 inputs.

    # Sign pattern of the order-28 base matrix: "+" is +1, anything else -1.
    # Rows are split on whitespace, so the layout of the literal is irrelevant.
    h28_pattern = """
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
"""

    def signs_to_matrix(pattern):
        # One matrix row per whitespace-separated token of the pattern.
        rows = []
        for line in pattern.split():
            rows.append([1 if ch == "+" else -1 for ch in line])
        return np.array(rows)

    base28 = signs_to_matrix(h28_pattern)

    # Fixed seed: the sequence of random draws below is deterministic.
    np.random.seed(7)
    dtypes = [np.float32, np.float16, np.int32]
    for dtype, m, k in product(dtypes, [1, 28], range(1, 15)):
        # Large m=28 cases are skipped because they're very slow in NumPy.
        too_slow_in_numpy = m > 1 and k > 8
        # NOTE(review): k == 14 is only exercised for float16; the reason is
        # not stated in this test -- confirm against the implementation.
        largest_k_not_half = k == 14 and dtype != np.float16
        if too_slow_in_numpy or largest_k_not_half:
            continue

        with self.subTest(dtype=dtype, m=m, k=k):
            n = m * 2**k
            batch = 4
            scale = 0.34
            samples = np.random.normal(size=(batch, n)).astype(dtype)

            # Contiguity check: feed every other row of the batch.
            x = mx.array(samples)[::2]
            y = mx.hadamard_transform(x, scale=scale)
            mx.eval(y)

            # Dense reference matrix: plain power-of-two Hadamard, or its
            # Kronecker product with the order-28 base when m == 28.
            if m == 1:
                reference = self._hadamard(2**k)
            else:
                reference = np.kron(base28, self._hadamard(2**k))
            expected = np.einsum("ij,bj->bi", reference, x) * scale

            # Tight tolerance for float32; the other dtypes get one that
            # grows with k.
            if dtype == np.float32:
                atol = 2e-4
            else:
                atol = 5e-2 * k
            np.testing.assert_allclose(y, expected, atol=atol)
def test_hadamard_grad_vmap(self):
    # Check that mx.vjp and doubly-nested mx.vmap of mx.hadamard_transform
    # agree with the same transforms applied to an explicit product h @ x.

    def double_vmap(fn, outer_axis):
        # Inner vmap over axis 0, outer vmap over `outer_axis` (in and out).
        return mx.vmap(mx.vmap(fn, 0, 0), outer_axis, outer_axis)

    # Fixed seed: the sequence of random draws below is deterministic.
    np.random.seed(4)
    for exponent in range(2, 8):
        n = 2**exponent
        primal = mx.array(np.random.normal(size=(n,))).astype(mx.float32)
        dense = mx.array(self._hadamard(n)).astype(mx.float32)
        cotangent = mx.array(np.random.normal(size=(n,))).astype(mx.float32)

        def matmul_reference(v):
            # Reference transform: multiply by the explicit Hadamard matrix.
            return dense @ v

        ref_vjp = mx.vjp(matmul_reference, [primal], [cotangent])
        fast_vjp = mx.vjp(mx.hadamard_transform, [primal], [cotangent])
        np.testing.assert_allclose(ref_vjp, fast_vjp, atol=1e-4)

        for axis in (0, 1, 2):
            ref_batched = double_vmap(matmul_reference, axis)
            fast_batched = double_vmap(mx.hadamard_transform, axis)
            cube = mx.array(np.random.normal(size=(n, n, n)))
            ref_out = ref_batched(cube)
            fast_out = fast_batched(cube)
            np.testing.assert_allclose(ref_out, fast_out, atol=1e-4)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()