mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2 * working for m*2^k * add scale and check contiguity * add size check * clean up * fix test * add grads + vmap * gpu only * skip on linux * test typo * add cpu impl * remove gpu only tests * fix linux build + add is_equivalent
This commit is contained in:
@@ -4372,6 +4372,35 @@ void init_ops(nb::module_& m) {
|
||||
a (array): Input array or scalar.
|
||||
dtype (Dtype): The data type to change to.
|
||||
|
||||
Returns:
|
||||
array: The array with the new type.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"hadamard_transform",
|
||||
&hadamard_transform,
|
||||
nb::arg(),
|
||||
"scale"_a = 1.0,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def hadamard_transform(a: array, float scale = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the Walsh-Hadamard transform along the final axis.
|
||||
|
||||
Equivalent to:
|
||||
```python
|
||||
from scipy.linalg import hadamard
|
||||
|
||||
y = hadamard(len(x)) @ x
|
||||
```
|
||||
|
||||
Supports sizes `n = m*2^k` where m in (1, 12, 20, 28)
|
||||
and 2^k <= 8192 for FP32 and 2^k <= 16384 for FP16/BF16.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
scale (float): Scale the output by this factor.
|
||||
|
||||
Returns:
|
||||
array: The array with the new type.
|
||||
)pbdoc");
|
||||
|
||||
@@ -2425,6 +2425,104 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a_out = out.view(mx.int32)
|
||||
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
|
||||
|
||||
def _hadamard(self, N):
|
||||
# Matches scipy.linalg.hadamard
|
||||
H = np.array([[1]], dtype=np.int64)
|
||||
for i in range(0, np.log2(N).astype(np.int64)):
|
||||
H = np.vstack((np.hstack((H, H)), np.hstack((H, -H))))
|
||||
return H
|
||||
|
||||
def test_hadamard(self):
|
||||
h28_str = """
|
||||
+------++----++-+--+-+--++--
|
||||
-+-----+++-----+-+--+-+--++-
|
||||
--+-----+++---+-+-+----+--++
|
||||
---+-----+++---+-+-+-+--+--+
|
||||
----+-----+++---+-+-+++--+--
|
||||
-----+-----++++--+-+--++--+-
|
||||
------++----++-+--+-+--++--+
|
||||
--++++-+-------++--+++-+--+-
|
||||
---++++-+-----+-++--+-+-+--+
|
||||
+---+++--+----++-++--+-+-+--
|
||||
++---++---+----++-++--+-+-+-
|
||||
+++---+----+----++-++--+-+-+
|
||||
++++--------+-+--++-++--+-+-
|
||||
-++++--------+++--++--+--+-+
|
||||
-+-++-++--++--+--------++++-
|
||||
+-+-++--+--++--+--------++++
|
||||
-+-+-++--+--++--+----+---+++
|
||||
+-+-+-++--+--+---+---++---++
|
||||
++-+-+-++--+------+--+++---+
|
||||
-++-+-+-++--+------+-++++---
|
||||
+-++-+---++--+------+-++++--
|
||||
-++--++-+-++-+++----++------
|
||||
+-++--++-+-++-+++-----+-----
|
||||
++-++---+-+-++-+++-----+----
|
||||
-++-++-+-+-+-+--+++-----+---
|
||||
--++-++++-+-+----+++-----+--
|
||||
+--++-+-++-+-+----+++-----+-
|
||||
++--++-+-++-+-+----++------+
|
||||
"""
|
||||
|
||||
def parse_h_string(h_str):
|
||||
return np.array(
|
||||
[[1 if s == "+" else -1 for s in row] for row in h_str.split()]
|
||||
)
|
||||
|
||||
h28 = parse_h_string(h28_str)
|
||||
|
||||
np.random.seed(7)
|
||||
tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15))
|
||||
for dtype, m, k in tests:
|
||||
# skip large m=28 cases because they're very slow in NumPy
|
||||
if (m > 1 and k > 8) or (dtype != np.float16 and k == 14):
|
||||
continue
|
||||
with self.subTest(dtype=dtype, m=m, k=k):
|
||||
n = m * 2**k
|
||||
b = 4
|
||||
scale = 0.34
|
||||
x = np.random.normal(size=(b, n)).astype(dtype)
|
||||
# contiguity check
|
||||
x = mx.array(x)[::2]
|
||||
y = mx.hadamard_transform(x, scale=scale)
|
||||
mx.eval(y)
|
||||
h = (
|
||||
self._hadamard(2**k)
|
||||
if m == 1
|
||||
else np.kron(h28, self._hadamard(2**k))
|
||||
)
|
||||
y_np = np.einsum("ij,bj->bi", h, x) * scale
|
||||
atol = 2e-4 if dtype == np.float32 else 5e-2 * k
|
||||
np.testing.assert_allclose(y, y_np, atol=atol)
|
||||
|
||||
def test_hadamard_grad_vmap(self):
|
||||
np.random.seed(4)
|
||||
|
||||
for k in range(2, 8):
|
||||
n = 2**k
|
||||
x = np.random.normal(size=(n,))
|
||||
h = self._hadamard(n)
|
||||
c = np.random.normal(size=(n,))
|
||||
x = mx.array(x).astype(mx.float32)
|
||||
h = mx.array(h).astype(mx.float32)
|
||||
c = mx.array(c).astype(mx.float32)
|
||||
|
||||
def hadamard_transform(x):
|
||||
return h @ x
|
||||
|
||||
out = mx.vjp(hadamard_transform, [x], [c])
|
||||
out_t = mx.vjp(mx.hadamard_transform, [x], [c])
|
||||
np.testing.assert_allclose(out, out_t, atol=1e-4)
|
||||
|
||||
for axis in (0, 1, 2):
|
||||
vht = mx.vmap(mx.vmap(hadamard_transform, 0, 0), axis, axis)
|
||||
vht_t = mx.vmap(mx.vmap(mx.hadamard_transform, 0, 0), axis, axis)
|
||||
|
||||
xb = mx.array(np.random.normal(size=(n, n, n)))
|
||||
out = vht(xb)
|
||||
out_t = vht_t(xb)
|
||||
np.testing.assert_allclose(out, out_t, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user