diff --git a/mlx/backend/common/hadamard.h b/mlx/backend/common/hadamard.h
index a8fed76b0..ba5c4e41e 100644
--- a/mlx/backend/common/hadamard.h
+++ b/mlx/backend/common/hadamard.h
@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
           "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
     }
   }
+  if (n > (1 << 26)) {
+    throw std::invalid_argument(
+        "[hadamard] Only supports n = m*2^k where k <= 26");
+  }
   return {n, m};
 }
 
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core
diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp
index a7dfc5f17..89b970fce 100644
--- a/mlx/backend/metal/hadamard.cpp
+++ b/mlx/backend/metal/hadamard.cpp
@@ -1,9 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
-#include <sstream>
-
-#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/hadamard.h"
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -15,7 +13,6 @@ namespace mlx::core {
 
 constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
-constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
 
 std::string gen_hadamard_codelet(int m) {
   // Generate a O(m^2) hadamard codelet for a given M
@@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) {
   return source.str();
 }
 
-void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
-  auto& s = stream();
+void hadamard_mn_contiguous(
+    const array& x,
+    array& y,
+    int m,
+    int n1,
+    int n2,
+    float scale,
+    metal::Device& d,
+    const Stream& s) {
+  int n = n1 * n2;
+  int read_width_n1 = n1 == 2 ? 2 : 4;
+  int read_width_n2 = n2 == 2 ? 2 : 4;
+  int read_width_m = (n == 2 || m == 28) ? 2 : 4;
+  int max_radix_1 = std::min(n1, 16);
+  int max_radix_2 = std::min(n2, 16);
+  float scale_n1 = 1.0;
+  float scale_n2 = (m == 1) ? scale : 1.0;
+  float scale_m = scale;
 
-  auto& in = inputs[0];
+  // n2 is a row contiguous power of 2 hadamard transform
+  MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
+  MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
 
-  std::vector<array> copies;
-  // Only support the last axis for now
-  int axis = in.ndim() - 1;
-  auto check_input = [&copies, &s](const array& x) {
-    // TODO(alexbarron) pass strides to kernel to relax this constraint
-    bool no_copy = x.flags().row_contiguous;
-    if (no_copy) {
-      return x;
-    } else {
-      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
-      copy_gpu(x, copies.back(), CopyType::General, s);
-      return copies.back();
+  // n1 is a strided power of 2 hadamard transform with stride n2
+  MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
+  MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
+
+  // m is a strided hadamard transform with stride n = n1 * n2
+  MTL::Size group_dims_m(
+      std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
+  MTL::Size grid_dims_m(
+      group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
+
+  // Make the kernel
+  std::string kname;
+  kname.reserve(32);
+  concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
+  auto lib = d.get_library(kname, [&]() {
+    std::string kernel;
+    concatenate(
+        kernel,
+        metal::utils(),
+        gen_hadamard_codelet(m),
+        metal::hadamard(),
+        get_template_definition(
+            "n2" + kname,
+            "hadamard_n",
+            get_type_string(x.dtype()),
+            n2,
+            max_radix_2,
+            read_width_n2));
+    if (n1 > 1) {
+      kernel += get_template_definition(
+          "n1" + kname,
+          "hadamard_n",
+          get_type_string(x.dtype()),
+          n1,
+          max_radix_1,
+          read_width_n1,
+          n2);
     }
-  };
-  const array& in_contiguous = check_input(in);
-
-  if (in_contiguous.is_donatable()) {
-    out.copy_shared_buffer(in_contiguous);
-  } else {
-    out.set_data(allocator::malloc(out.nbytes()));
-  }
-
-  int n, m;
-  std::tie(n, m) = decompose_hadamard(in.shape(axis));
-
-  if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
-    throw std::invalid_argument(
-        "[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
-  }
-
-  int max_radix = std::min(n, 16);
-  // Use read_width 2 for m = 28 to avoid register spilling
-  int read_width = (n == 2 || m == 28) ? 2 : 4;
-
-  std::ostringstream kname;
-  kname << "hadamard_" << n * m << "_" << type_to_name(out);
-  auto kernel_name = kname.str();
-  auto& d = metal::device(s.device);
-  const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    auto codelet = gen_hadamard_codelet(m);
-    kernel_source << metal::utils() << codelet << metal::hadamard();
-    kernel_source << get_template_definition(
-        "n" + kernel_name,
-        "hadamard_n",
-        get_type_string(in.dtype()),
-        n,
-        max_radix,
-        read_width);
-    kernel_source << get_template_definition(
-        "m" + kernel_name,
-        "hadamard_m",
-        get_type_string(in.dtype()),
-        n,
-        m,
-        read_width);
-    return kernel_source.str();
+    if (m > 1) {
+      kernel += get_template_definition(
+          "m" + kname,
+          "hadamard_m",
+          get_type_string(x.dtype()),
+          n,
+          m,
+          read_width_m);
+    }
+    return kernel;
   });
 
-  int batch_size = in.size() / n;
-  int threads_per = n / max_radix;
-
-  auto& compute_encoder = d.get_command_encoder(s.index);
-
-  auto launch_hadamard = [&](const array& in,
-                             array& out,
-                             const std::string& kernel_name,
-                             float scale) {
-    auto kernel = d.get_kernel(kernel_name, lib);
-    assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
-
+  // Launch the strided transform for n1
+  if (n1 > 1) {
+    auto& compute_encoder = d.get_command_encoder(s.index);
+    auto kernel = d.get_kernel("n1" + kname, lib);
     compute_encoder.set_compute_pipeline_state(kernel);
-    compute_encoder.set_input_array(in, 0);
-    compute_encoder.set_output_array(out, 1);
-    compute_encoder.set_bytes(scale, 2);
-
-    MTL::Size group_dims = MTL::Size(1, threads_per, 1);
-    MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
-    compute_encoder.dispatch_threads(grid_dims, group_dims);
-  };
-
-  if (m > 1) {
-    // When m is greater than 1, we decompose the
-    // computation into two uploads to the GPU:
-    //
-    // e.g. len(x) = 12*4 = 48, m = 12, n = 4
-    //
-    // y = h48 @ x
-    //
-    // Upload 1:
-    // tmp = a.reshape(12, 4) @ h4
-    //
-    // Upload 2:
-    // y = h12 @ tmp
-    array temp(in.shape(), in.dtype(), nullptr, {});
-    temp.set_data(allocator::malloc(temp.nbytes()));
-    copies.push_back(temp);
-
-    launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
-
-    // Metal sometimes reports 256 max threads per group for hadamard_m kernel
-    threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
-    batch_size = in.size() / m / read_width / threads_per;
-    launch_hadamard(temp, out, "m" + kernel_name, scale_);
-  } else {
-    launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
+    compute_encoder.set_input_array(x, 0);
+    compute_encoder.set_output_array(y, 1);
+    compute_encoder.set_bytes(scale_n1, 2);
+    compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
   }
 
-  d.add_temporaries(std::move(copies), s.index);
+  // Launch the transform for n2
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel("n2" + kname, lib);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
+  compute_encoder.set_output_array(y, 1);
+  compute_encoder.set_bytes(scale_n2, 2);
+  compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
+
+  // Launch the strided transform for m
+  if (m > 1) {
+    auto kernel = d.get_kernel("m" + kname, lib);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(y, 0);
+    compute_encoder.set_output_array(y, 1);
+    compute_encoder.set_bytes(scale_m, 2);
+    compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
+  }
+}
+
+void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  auto& in = inputs[0];
+
+  // Split the hadamard transform so that all of them work on vectors smaller
+  // than 8192 elements.
+  //
+  // We decompose it in the following way:
+  //
+  // n = m * n1 * n2 = m * 2^k1 * 2^k2
+  //
+  // where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
+  auto [n, m] = decompose_hadamard(in.shape().back());
+  int n1 = 1, n2 = n;
+  if (n > 8192) {
+    for (n2 = 2; n2 * n2 < n; n2 *= 2) {
+    }
+    n1 = n / n2;
+  }
+
+  if (in.flags().row_contiguous) {
+    if (in.is_donatable()) {
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(allocator::malloc(out.nbytes()));
+    }
+    hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
+  } else {
+    copy_gpu(in, out, CopyType::General, s);
+    hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
+  }
 }
 
 } // namespace mlx::core
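The split implemented by hadamard_mn_contiguous relies on the Kronecker structure of Sylvester Hadamard matrices: H(n1*n2) = kron(H(n1), H(n2)), so one length-n transform can be done as a strided pass over n1 plus a contiguous pass over the last axis of length n2 (the two passes commute). The m factor splits off the same way, as the comment removed from the old eval_gpu shows for m = 12. A minimal NumPy/SciPy sketch of that identity, illustrative only and not part of the patch:

import numpy as np
from scipy.linalg import hadamard

n1, n2 = 8, 16
n = n1 * n2
x = np.random.randn(n)

# Reference: one length-n transform, y = H_n @ x
y_ref = hadamard(n) @ x

# Two passes on the reshaped vector, mirroring the kernel launches
# (Hadamard matrices are symmetric, so no transposes are needed):
X = x.reshape(n1, n2)
tmp = hadamard(n1) @ X   # strided pass: transform along axis 0 (stride n2)
Y = tmp @ hadamard(n2)   # contiguous pass: transform along the last axis
assert np.allclose(Y.reshape(-1), y_ref)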
diff --git a/mlx/backend/metal/kernels/hadamard.h b/mlx/backend/metal/kernels/hadamard.h
index 93e2fb8a8..9f2311c10 100644
--- a/mlx/backend/metal/kernels/hadamard.h
+++ b/mlx/backend/metal/kernels/hadamard.h
@@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) {
   }
 }
 
-template <typename T, int N, int max_radix, int read_width>
+template <typename T, int N, int max_radix, int read_width, int stride = 1>
 [[kernel]] void hadamard_n(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
@@ -46,18 +46,25 @@ template <typename T, int N, int max_radix, int read_width>
   constexpr short logFinal = logN % logR;
   constexpr short final_radix = 1 << (logFinal);
 
-  int batch_idx = elem.x * N;
-  short i = elem.y;
+  int batch_idx = elem.y * N * stride + elem.z;
+  short i = elem.x;
 
   threadgroup T buf[N];
 
   // Read values from device
-  STEEL_PRAGMA_UNROLL
-  for (short j = 0; j < max_radix / read_width; j++) {
-    short index = j * read_width * num_threads + i * read_width;
+  if (stride == 1) {
     STEEL_PRAGMA_UNROLL
-    for (short r = 0; r < read_width; r++) {
-      buf[index + r] = in[batch_idx + index + r];
+    for (short j = 0; j < max_radix / read_width; j++) {
+      short index = j * read_width * num_threads + i * read_width;
+      STEEL_PRAGMA_UNROLL
+      for (short r = 0; r < read_width; r++) {
+        buf[index + r] = in[batch_idx + index + r];
+      }
+    }
+  } else {
+    STEEL_PRAGMA_UNROLL
+    for (short j = 0; j < max_radix; j++) {
+      buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
     }
   }
 
@@ -113,12 +120,20 @@ template <typename T, int N, int max_radix, int read_width>
   }
 
   // Write values to device
-  STEEL_PRAGMA_UNROLL
-  for (short j = 0; j < max_radix / read_width; j++) {
-    short index = j * read_width * num_threads + i * read_width;
+  if (stride == 1) {
     STEEL_PRAGMA_UNROLL
-    for (short r = 0; r < read_width; r++) {
-      out[batch_idx + index + r] = T(buf[index + r] * scale);
+    for (short j = 0; j < max_radix / read_width; j++) {
+      short index = j * read_width * num_threads + i * read_width;
+      STEEL_PRAGMA_UNROLL
+      for (short r = 0; r < read_width; r++) {
+        out[batch_idx + index + r] = T(buf[index + r] * scale);
+      }
+    }
+  } else {
+    STEEL_PRAGMA_UNROLL
+    for (short j = 0; j < max_radix; j++) {
+      out[batch_idx + (j * num_threads + i) * stride] =
+          buf[j * num_threads + i];
     }
   }
 }
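The new stride template parameter only changes how hadamard_n addresses memory: with stride == 1 a threadgroup reads a contiguous length-N row, while with stride == n2 it walks one column of an (n1, n2) block. A small pure-Python sketch of the index pattern, with names that mirror the kernel but are otherwise illustrative only:

# Indices one simulated threadgroup touches in the strided (n1) launch,
# following `batch_idx = elem.y * N * stride + elem.z` and the read
# `in[batch_idx + (j * num_threads + i) * stride]`.
N, max_radix = 8, 4            # N = n1 for the strided launch
num_threads = N // max_radix   # threads per threadgroup
stride = 4                     # n2 in the strided launch
batch, col = 0, 3              # elem.y and elem.z

batch_idx = batch * N * stride + col
indices = [
    batch_idx + (j * num_threads + i) * stride
    for i in range(num_threads)
    for j in range(max_radix)
]
# Every stride-th element starting at `col`: one column of an (N, stride) block.
assert sorted(indices) == [col + t * stride for t in range(N)]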
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index e7abe12db..4aa5e88b7 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -473,8 +473,19 @@ array hadamard_transform(
     std::optional<float> scale_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
-  float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
+  int n = a.ndim() > 0 ? a.shape(-1) : 1;
+  float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);
   auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
+
+  // Nothing to do for a scalar
+  if (n == 1) {
+    if (scale == 1) {
+      return a;
+    }
+
+    return multiply(a, array(scale, dtype), s);
+  }
+
   return array(
       a.shape(),
       dtype,
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index d840eac7d..d9e143d82 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2868,11 +2868,33 @@ class TestOps(mlx_tests.MLXTestCase):
 
         h28 = parse_h_string(h28_str)
 
+        x = mx.array(5)
+        y = mx.hadamard_transform(x)
+        self.assertEqual(y.item(), 5)
+
+        x = mx.array(5)
+        y = mx.hadamard_transform(x, scale=0.2)
+        self.assertEqual(y.item(), 1)
+
+        x = mx.random.normal((8, 8, 1))
+        y = mx.hadamard_transform(x)
+        self.assertTrue(mx.all(y == x).item())
+
+        # Too slow to compare to numpy so let's compare CPU to GPU
+        if mx.default_device() == mx.gpu:
+            rk = mx.random.key(42)
+            for k in range(14, 17):
+                for m in [1, 3, 5, 7]:
+                    x = mx.random.normal((4, m * 2**k), key=rk)
+                    y1 = mx.hadamard_transform(x, stream=mx.cpu)
+                    y2 = mx.hadamard_transform(x, stream=mx.gpu)
+                    self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6)
+
         np.random.seed(7)
-        tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15))
+        tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14))
         for dtype, m, k in tests:
             # skip large m=28 cases because they're very slow in NumPy
-            if (m > 1 and k > 8) or (dtype != np.float16 and k == 14):
+            if m > 1 and k > 8:
                 continue
             with self.subTest(dtype=dtype, m=m, k=k):
                 n = m * 2**k
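For reference, the n1/n2 split chosen by the new Hadamard::eval_gpu can be checked in isolation. The sketch below, illustrative only, transcribes that loop to Python and shows that for any power of two n <= 2^26 both factors stay at or below 8192, which lines up with the new k <= 26 check in decompose_hadamard:

def split(n):
    # Mirrors the new eval_gpu: pick n2 as the smallest power of two with
    # n2 * n2 >= n, then n1 = n // n2.
    n1, n2 = 1, n
    if n > 8192:
        n2 = 2
        while n2 * n2 < n:
            n2 *= 2
        n1 = n // n2
    return n1, n2

for k in range(13, 27):
    n1, n2 = split(2**k)
    assert n1 * n2 == 2**k and n1 <= 8192 and n2 <= 8192
# e.g. split(2**14) == (128, 128) and split(2**26) == (8192, 8192)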