mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard * add scale * review comments --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -4379,11 +4379,11 @@ void init_ops(nb::module_& m) {
|
||||
"hadamard_transform",
|
||||
&hadamard_transform,
|
||||
nb::arg(),
|
||||
"scale"_a = 1.0,
|
||||
"scale"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def hadamard_transform(a: array, float scale = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def hadamard_transform(a: array, Optional[float] scale = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the Walsh-Hadamard transform along the final axis.
|
||||
|
||||
@@ -4393,7 +4393,7 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
from scipy.linalg import hadamard
|
||||
|
||||
y = hadamard(len(x)) @ x
|
||||
y = (hadamard(len(x)) @ x) * scale
|
||||
|
||||
Supports sizes ``n = m*2^k`` for ``m`` in ``(1, 12, 20, 28)`` and ``2^k
|
||||
<= 8192`` for float32 and ``2^k <= 16384`` for float16/bfloat16.
|
||||
@@ -4401,6 +4401,7 @@ void init_ops(nb::module_& m) {
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
scale (float): Scale the output by this factor.
|
||||
Defaults to ``1/sqrt(a.shape[-1])`` so that the Hadamard matrix is orthonormal.
|
||||
|
||||
Returns:
|
||||
array: The transformed array.
|
||||
|
@@ -2496,6 +2496,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
atol = 2e-4 if dtype == np.float32 else 5e-2 * k
|
||||
np.testing.assert_allclose(y, y_np, atol=atol)
|
||||
|
||||
# bfloat16 emulation on M1 means 2**14 doesn't fit in threadgroup memory
|
||||
if dtype == np.float16 and k < 14:
|
||||
y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16), scale=scale)
|
||||
np.testing.assert_allclose(
|
||||
y_bf16.astype(mx.float16), y, atol=atol * 2
|
||||
)
|
||||
|
||||
def test_hadamard_grad_vmap(self):
|
||||
np.random.seed(4)
|
||||
|
||||
@@ -2509,7 +2516,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
c = mx.array(c).astype(mx.float32)
|
||||
|
||||
def hadamard_transform(x):
|
||||
return h @ x
|
||||
return h @ x / mx.sqrt(x.shape[-1])
|
||||
|
||||
out = mx.vjp(hadamard_transform, [x], [c])
|
||||
out_t = mx.vjp(mx.hadamard_transform, [x], [c])
|
||||
|
Reference in New Issue
Block a user