From c34a5ae7f7057a7cb87807d91b3e910783e6d444 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Tue, 23 Jul 2024 14:54:43 -0700
Subject: [PATCH] Fix bfloat16 Hadamard (#1283)

* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron
---
 mlx/backend/metal/kernels/hadamard.h | 8 ++++----
 mlx/ops.cpp                          | 4 +++-
 mlx/ops.h                            | 2 +-
 python/src/ops.cpp                   | 7 ++++---
 python/tests/test_ops.py             | 9 ++++++++-
 5 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/mlx/backend/metal/kernels/hadamard.h b/mlx/backend/metal/kernels/hadamard.h
index da4050cf4..93e2fb8a8 100644
--- a/mlx/backend/metal/kernels/hadamard.h
+++ b/mlx/backend/metal/kernels/hadamard.h
@@ -80,7 +80,7 @@ template
     STEEL_PRAGMA_UNROLL
     for (short r = 0; r < max_radix; r++) {
-      buf[j + h * r] = x[r];
+      buf[j + h * r] = T(x[r]);
     }

     h <<= logR;
@@ -106,7 +106,7 @@ template
       STEEL_PRAGMA_UNROLL
       for (short r = 0; r < final_radix; r++) {
-        buf[j + h * r] = x[r];
+        buf[j + h * r] = T(x[r]);
       }
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -118,7 +118,7 @@ template
     short index = j * read_width * num_threads + i * read_width;
     STEEL_PRAGMA_UNROLL
     for (short r = 0; r < read_width; r++) {
-      out[batch_idx + index + r] = buf[index + r] * scale;
+      out[batch_idx + index + r] = T(buf[index + r] * scale);
     }
   }
 }
@@ -161,7 +161,7 @@ template
   for (short c = 0; c < M; c++) {
     STEEL_PRAGMA_UNROLL
     for (short r = 0; r < read_width; r++) {
-      out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
+      out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
     }
   }
 }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index f9051e243..5f17241b4 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -453,8 +453,10 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {

 array hadamard_transform(
     const array& a,
-    float scale /* = 1.0 */,
+    std::optional<float> scale_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
+  // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
+  float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
   auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
   return array(
       a.shape(),
diff --git a/mlx/ops.h b/mlx/ops.h
index fb07bf1fa..d637633b6 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -134,7 +134,7 @@ array flatten(const array& a, StreamOrDevice s = {});

 /** Multiply the array by the Hadamard matrix of corresponding size. */
 array hadamard_transform(
     const array& a,
-    float scale = 1.0f,
+    std::optional<float> scale = std::nullopt,
     StreamOrDevice s = {});

 /** Remove singleton dimensions at the given axes. */
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index e59e56ebc..f2b10a5dd 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4379,11 +4379,11 @@ void init_ops(nb::module_& m) {
       "hadamard_transform",
       &hadamard_transform,
       nb::arg(),
-      "scale"_a = 1.0,
+      "scale"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def hadamard_transform(a: array, scale: float = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
+          "def hadamard_transform(a: array, scale: Optional[float] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the Walsh-Hadamard transform along the final axis.

         Equivalent to:

         .. code-block:: python

            from scipy.linalg import hadamard

-           y = hadamard(len(x)) @ x
+           y = (hadamard(len(x)) @ x) * scale

         Supports sizes ``n = m*2^k`` for ``m`` in ``(1, 12, 20, 28)`` and
         ``2^k <= 8192`` for float32 and ``2^k <= 16384`` for float16/bfloat16.
@@ -4401,6 +4401,7 @@ void init_ops(nb::module_& m) {
         Args:
             a (array): Input array or scalar.
             scale (float): Scale the output by this factor.
+                Defaults to ``1/sqrt(a.shape[-1])`` so that the Hadamard matrix is orthonormal.

         Returns:
             array: The transformed array.
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 60a0118c0..5b83613f7 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2496,6 +2496,13 @@ class TestOps(mlx_tests.MLXTestCase):
             atol = 2e-4 if dtype == np.float32 else 5e-2 * k
             np.testing.assert_allclose(y, y_np, atol=atol)

+            # bfloat16 emulation on M1 means 2**14 doesn't fit in threadgroup memory
+            if dtype == np.float16 and k < 14:
+                y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16), scale=scale)
+                np.testing.assert_allclose(
+                    y_bf16.astype(mx.float16), y, atol=atol * 2
+                )
+
     def test_hadamard_grad_vmap(self):
         np.random.seed(4)

             c = mx.array(c).astype(mx.float32)

             def hadamard_transform(x):
-                return h @ x
+                return h @ x / mx.sqrt(x.shape[-1])

             out = mx.vjp(hadamard_transform, [x], [c])
             out_t = mx.vjp(mx.hadamard_transform, [x], [c])
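
A quick sketch (not part of the patch) of what the change means from Python, assuming the patched mlx.core API: with scale omitted, hadamard_transform now applies 1/sqrt(n), so it matches scipy's Hadamard matrix divided by sqrt(n), and bfloat16 inputs run through the kernel's new explicit casts. The size n = 64 and the tolerances below are illustrative choices, not values from the test suite.

import numpy as np
import mlx.core as mx
from scipy.linalg import hadamard

n = 64
x = mx.array(np.random.normal(size=(n,)).astype(np.float32))

# Default scale is now 1/sqrt(n), i.e. an orthonormal transform.
y_default = mx.hadamard_transform(x)
y_explicit = mx.hadamard_transform(x, scale=1.0 / n**0.5)
y_ref = hadamard(n) @ np.array(x) / np.sqrt(n)

np.testing.assert_allclose(np.array(y_default), np.array(y_explicit), atol=1e-6)
np.testing.assert_allclose(np.array(y_default), y_ref, atol=1e-4)

# bfloat16 input: compare in float32 with a loose tolerance appropriate
# for bfloat16's 8-bit mantissa (illustrative tolerance).
y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16))
np.testing.assert_allclose(np.array(y_bf16.astype(mx.float32)), y_ref, atol=1e-1)

Passing scale=1.0 explicitly reproduces the old unnormalized behavior.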