diff --git a/benchmarks/python/hadamard_bench.py b/benchmarks/python/hadamard_bench.py
new file mode 100644
index 000000000..2fe45473c
--- /dev/null
+++ b/benchmarks/python/hadamard_bench.py
@@ -0,0 +1,70 @@
+import argparse
+
+import matplotlib
+import mlx.core as mx
+import numpy as np
+from time_utils import measure_runtime
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+
+def had(x):
+    y = mx.hadamard_transform(x)
+    mx.eval(y)
+
+
+def copy(x):
+    y = x + 1.0
+    mx.eval(y)
+
+
+def run(dtype):
+    system_size = 2**26
+    outputs = {}
+    for test_fn in (had, copy):
+        for m in [1, 12, 20, 28]:
+            if test_fn == copy:
+                key = "copy"
+            elif m == 1:
+                key = "had_2^k"
+            else:
+                key = "had_m*2^k"
+            outputs.setdefault(key, {})
+            for k in range(7, 14):
+                n = m * 2**k
+                if n > 2**15:
+                    continue
+                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
+                x = mx.array(x_np)
+                runtime_ms = measure_runtime(test_fn, x=x)
+                bytes_per_gb = 1e9
+                ms_per_s = 1e3
+                bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
+                bandwidth_gb = (
+                    system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
+                )
+                print(n, bandwidth_gb)
+                outputs[key][n] = bandwidth_gb
+
+    colors = {
+        "copy": "black",
+        "had_2^k": "steelblue",
+        "had_m*2^k": "skyblue",
+    }
+    for key, output in outputs.items():
+        plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
+    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
+    plt.xlabel("N")
+    plt.ylabel("Bandwidth (GB/s)")
+    plt.legend()
+    plt.savefig(f"bench_{dtype.__name__}.png")
+    plt.clf()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--fp16", action="store_true")
+    args = parser.parse_args()
+    dtype = np.float16 if args.fp16 else np.float32
+    run(dtype)
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 0b50ea244..0d75f7d62 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -72,6 +72,7 @@ Operations
    gather_qmm
    greater
    greater_equal
+   hadamard_transform
    identity
    inner
    isclose
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 778840115..6de0a6416 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -50,6 +50,7 @@ DEFAULT(GatherMM)
 DEFAULT(GatherQMM)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
+DEFAULT(Hadamard)
 DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index b5650b395..aa0f3dab0 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -42,6 +42,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 5164d9579..c5b5e44b8 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -68,6 +68,7 @@ DEFAULT(Full)
 DEFAULT(Gather)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
+DEFAULT(Hadamard)
 DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)
diff --git a/mlx/backend/common/hadamard.cpp b/mlx/backend/common/hadamard.cpp
new file mode 100644
index 000000000..6c71eaf9d
--- /dev/null
+++ b/mlx/backend/common/hadamard.cpp
@@ -0,0 +1,107 @@
+// Copyright © 2024 Apple Inc.
+
+#include <cassert>
+
+#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/hadamard.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+// n = 2^k component
+template <typename T>
+void hadamard_n(array& out, int n, int m, float scale) {
+  for (int b = 0; b < out.size() / n; b++) {
+    size_t loc = b * n;
+    T* data_ptr = out.data<T>() + loc;
+    int h = 1;
+    int n_over_2 = n / 2;
+    while (h < n) {
+      for (int i = 0; i < n / 2; i++) {
+        int k = i & (h - 1);
+        int j = ((i - k) << 1) + k;
+        float x = *(data_ptr + j);
+        float y = *(data_ptr + j + h);
+        *(data_ptr + j) = x + y;
+        *(data_ptr + j + h) = x - y;
+        if (h == n_over_2) {
+          *(data_ptr + j) *= scale;
+          *(data_ptr + j + h) *= scale;
+        }
+      }
+      h <<= 1;
+    }
+  }
+}
+
+// m component
+template <typename T>
+void hadamard_m(array& out, int n, int m, float scale) {
+  auto h_matrices = hadamard_matrices();
+  auto& matrix = h_matrices[m];
+  auto start = 1;
+  auto end = matrix.find('\n', start);
+  std::vector<bool> hmat_vec;
+  while (end != std::string_view::npos) {
+    auto row = matrix.substr(start, end - start);
+    for (int i = 0; i < row.length(); i++) {
+      hmat_vec.push_back(row[i] == '+');
+    }
+    start = end + 1;
+    end = matrix.find('\n', start);
+  }
+
+  for (int b = 0; b < out.size() / m / n; b++) {
+    size_t loc = b * n * m;
+    T* data_ptr = out.data<T>() + loc;
+    for (int i = 0; i < n; i++) {
+      std::vector<float> out(m);
+      for (int j = 0; j < m; j++) {
+        for (int k = 0; k < m; k++) {
+          float x = *(data_ptr + i + k * n);
+          if (hmat_vec[k + j * m]) {
+            out[j] += x;
+          } else {
+            out[j] -= x;
+          }
+        }
+      }
+      for (int j = 0; j < m; j++) {
+        *(data_ptr + i + j * n) = out[j] * scale;
+      }
+    }
+  }
+}
+
+template <typename T>
+void hadamard(array& out, int n, int m, float scale) {
+  float n_scale = m > 1 ? 1.0 : scale;
+  hadamard_n<T>(out, n, m, n_scale);
+  if (m > 1) {
+    hadamard_m<T>(out, n, m, scale);
+  }
+}
+
+void Hadamard::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+
+  // Copy input to output
+  copy(in, out, CopyType::General);
+
+  int axis = out.ndim() - 1;
+  auto [n, m] = decompose_hadamard(out.shape(axis));
+
+  switch (in.dtype()) {
+    case float32:
+      return hadamard<float>(out, n, m, scale_);
+    case float16:
+      return hadamard<float16_t>(out, n, m, scale_);
+    case bfloat16:
+      return hadamard<bfloat16_t>(out, n, m, scale_);
+    default:
+      throw std::invalid_argument("[hadamard] Unsupported type.");
+  }
+}
+
+} // namespace mlx::core
\ No newline at end of file
diff --git a/mlx/backend/common/hadamard.h b/mlx/backend/common/hadamard.h
new file mode 100644
index 000000000..a8fed76b0
--- /dev/null
+++ b/mlx/backend/common/hadamard.h
@@ -0,0 +1,105 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <map>
+
+#include "mlx/utils.h"
+
+namespace mlx::core {
+
+// From http://neilsloane.com/hadamard/
+constexpr std::string_view h12 = R"(
++-++++++++++
+--+-+-+-+-+-
++++-++----++
++---+--+-++-
++++++-++----
++-+---+--+-+
+++--+++-++--
++--++---+--+
+++----+++-++
++--+-++---+-
+++++----+++-
++-+--+-++---
+)";
+
+constexpr std::string_view h20 = R"(
++----+----++--++-++-
+-+----+---+++---+-++
+--+----+---+++-+-+-+
+---+----+---+++++-+-
+----+----++--++-++-+
+-+++++-----+--+++--+
++-+++-+---+-+--+++--
+++-++--+---+-+--+++-
++++-+---+---+-+--+++
+++++-----++--+-+--++
+--++-+-++-+-----++++
+---++-+-++-+---+-+++
++---++-+-+--+--++-++
+++---++-+----+-+++-+
+-++---++-+----+++++-
+-+--+--++-+----+----
++-+-----++-+----+---
+-+-+-+---+--+----+--
+--+-+++------+----+-
++--+--++------+----+
+)";
+
+constexpr std::string_view h28 = R"(
++------++----++-+--+-+--++--
+-+-----+++-----+-+--+-+--++-
+--+-----+++---+-+-+----+--++
+---+-----+++---+-+-+-+--+--+
+----+-----+++---+-+-+++--+--
+-----+-----++++--+-+--++--+-
+------++----++-+--+-+--++--+
+--++++-+-------++--+++-+--+-
+---++++-+-----+-++--+-+-+--+
++---+++--+----++-++--+-+-+--
+++---++---+----++-++--+-+-+-
++++---+----+----++-++--+-+-+
+++++--------+-+--++-++--+-+-
+-++++--------+++--++--+--+-+
+-+-++-++--++--+--------++++-
++-+-++--+--++--+--------++++
+-+-+-++--+--++--+----+---+++
++-+-+-++--+--+---+---++---++
+++-+-+-++--+------+--+++---+
+-++-+-+-++--+------+-++++---
++-++-+---++--+------+-++++--
+-++--++-+-++-+++----++------
++-++--++-+-++-+++-----+-----
+++-++---+-+-++-+++-----+----
+-++-++-+-+-+-+--+++-----+---
+--++-++++-+-+----+++-----+--
++--++-+-++-+-+----+++-----+-
+++--++-+-++-+-+----++------+
+)";
+
+inline const std::map<int, std::string_view> hadamard_matrices() {
+  return {{12, h12}, {20, h20}, {28, h28}};
+}
+
+inline std::pair<int, int> decompose_hadamard(int n) {
+  // n = m*2^k
+  int m = 1;
+  if (!is_power_of_2(n)) {
+    auto h_matrices = hadamard_matrices();
+    for (auto [factor, _] : h_matrices) {
+      if (n % factor == 0) {
+        m = factor;
+        n /= factor;
+        break;
+      }
+    }
+    if (m == 1) {
+      throw std::invalid_argument(
+          "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
+    }
+  }
+  return {n, m};
+}
+
+} // namespace mlx::core
\ No newline at end of file
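(Editor's note, not part of the diff.) `decompose_hadamard` above factors the transform length as `n = m*2^k` with `m` in `{1, 12, 20, 28}`. The corresponding Hadamard matrix is then the Kronecker product `kron(H_m, H_{2^k})`, which is why both the CPU implementation above and the two-dispatch Metal path below can apply `H_{2^k}` within each contiguous block of length `2^k` first and `H_m` across the `m` blocks second. A minimal NumPy sketch of that equivalence follows; the helper name and the `m = 2` stand-in matrix are illustrative only and do not appear in this change:

```python
import numpy as np
from scipy.linalg import hadamard


def hadamard_two_stage(x, h_m, k):
    # Stage 1: apply H_{2^k} within each contiguous block of length 2^k.
    # Stage 2: apply H_m across the m blocks.
    # Together this computes kron(h_m, H_{2^k}) @ x.
    m = h_m.shape[0]
    blocks = x.reshape(m, 2**k)
    stage1 = blocks @ hadamard(2**k)  # H_{2^k} is symmetric
    stage2 = h_m @ stage1
    return stage2.reshape(-1)


# H_2 stands in for one of the 12/20/28 matrices above; any Hadamard
# matrix of size m factors the same way.
h_2 = np.array([[1, 1], [1, -1]])
k = 3
x = np.random.normal(size=(2 * 2**k,))
np.testing.assert_allclose(
    hadamard_two_stage(x, h_2, k), hadamard(2 * 2**k) @ x, atol=1e-12
)
```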
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index b23c5af36..2a58df9cc 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -52,6 +52,7 @@ make_jit_source(
 )
 make_jit_source(scatter)
 make_jit_source(gather)
+make_jit_source(hadamard)
 
 if (MLX_METAL_JIT)
   target_sources(
@@ -132,6 +133,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index 5140a8e7a..de0edce51 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -14,6 +14,7 @@
 #include "mlx/backend/metal/utils.h"
 #include "mlx/mlx.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"
 
 namespace mlx::core {
 
diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp
new file mode 100644
index 000000000..46d77b03e
--- /dev/null
+++ b/mlx/backend/metal/hadamard.cpp
@@ -0,0 +1,203 @@
+// Copyright © 2024 Apple Inc.
+
+#include <cassert>
+
+#include "mlx/backend/common/compiled.h"
+#include "mlx/backend/common/hadamard.h"
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/jit/includes.h"
+#include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
+constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
+
+std::string gen_hadamard_codelet(int m) {
+  // Generate a O(m^2) hadamard codelet for a given M
+  // using the hadamard matrices above
+  //
+  // e.g. m = 2
+  // METAL_FUNC void hadamard_radix_m(thread float *x) {
+  //   float tmp[2];
+  //   tmp[0] = + x[0] + x[1];
+  //   tmp[1] = + x[0] - x[1];
+  //   for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
+  // }
+  //
+  auto h_matrices = hadamard_matrices();
+  auto& matrix = h_matrices[m];
+
+  std::ostringstream source;
+  source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
+  if (m == 1) {
+    source << "}" << std::endl;
+    return source.str();
+  }
+  source << "  float tmp[" << m << "];" << std::endl;
+  auto start = 1;
+  auto end = matrix.find('\n', start);
+
+  int index = 0;
+  while (end != std::string_view::npos) {
+    source << "  tmp[" << index << "] = ";
+    auto row = matrix.substr(start, end - start);
+    for (int i = 0; i < row.length(); i++) {
+      source << " " << row[i] << " x[" << i << "]";
+    }
+    source << ";" << std::endl;
+    start = end + 1;
+    end = matrix.find('\n', start);
+    index++;
+  }
+  source << "  for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
+         << std::endl;
+  source << "}" << std::endl;
+  return source.str();
+}
+
+void launch_hadamard(
+    const array& in,
+    array& out,
+    int batch_size,
+    int threads_per,
+    const std::string kernel_name,
+    float scale,
+    const Stream& s) {
+  auto& d = metal::device(s.device);
+
+  const auto& lib_name = kernel_name.substr(1);
+  auto lib = d.get_library(lib_name);
+  auto kernel = d.get_kernel(kernel_name, lib);
+  assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(out, 1);
+  compute_encoder->setBytes(&scale, sizeof(float), 2);
+
+  MTL::Size group_dims = MTL::Size(1, threads_per, 1);
+  MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
+  compute_encoder->dispatchThreads(grid_dims, group_dims);
+}
+
+void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = stream();
+
+  auto& in = inputs[0];
+
+  std::vector<array> copies;
+  // Only support the last axis for now
+  int axis = in.ndim() - 1;
+  auto check_input = [&copies, &s](const array& x) {
+    // TODO(alexbarron) pass strides to kernel to relax this constraint
+    bool no_copy = x.flags().row_contiguous;
+    if (no_copy) {
+      return x;
+    } else {
+      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
+      copy_gpu(x, copies.back(), CopyType::General, s);
+      return copies.back();
+    }
+  };
+  const array& in_contiguous = check_input(in);
+
+  if (in_contiguous.is_donatable()) {
+    out.move_shared_buffer(in_contiguous);
+  } else {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  }
+
+  auto [n, m] = decompose_hadamard(in.shape(axis));
+
+  if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
+    throw std::invalid_argument(
+        "[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
+  }
+
+  int max_radix = std::min(n, 16);
+  // Use read_width 2 for m = 28 to avoid register spilling
+  int read_width = (n == 2 || m == 28) ? 2 : 4;
+
+  std::ostringstream kname;
+  kname << "hadamard_" << n * m << "_" << type_to_name(out);
+  auto kernel_name = kname.str();
+  auto& d = metal::device(s.device);
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    auto codelet = gen_hadamard_codelet(m);
+    kernel_source << metal::utils() << codelet << metal::hadamard();
+    kernel_source << get_template_definition(
+        "n" + kernel_name,
+        "hadamard_n",
+        get_type_string(in.dtype()),
+        n,
+        max_radix,
+        read_width);
+    kernel_source << get_template_definition(
+        "m" + kernel_name,
+        "hadamard_m",
+        get_type_string(in.dtype()),
+        n,
+        m,
+        read_width);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+
+  int batch_size = in.size() / n;
+  int threads_per = n / max_radix;
+
+  if (m > 1) {
+    // When m is greater than 1, we decompose the
+    // computation into two uploads to the GPU:
+    //
+    // e.g. len(x) = 12*4 = 48, m = 12, n = 4
+    //
+    // y = h48 @ x
+    //
+    // Upload 1:
+    // tmp = x.reshape(12, 4) @ h4
+    //
+    // Upload 2:
+    // y = h12 @ tmp
+    array temp(in.shape(), in.dtype(), nullptr, {});
+    temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
+    copies.push_back(temp);
+
+    launch_hadamard(
+        in_contiguous,
+        temp,
+        batch_size,
+        threads_per,
+        "n" + kernel_name,
+        1.0,
+        s);
+
+    // Metal sometimes reports 256 max threads per group for hadamard_m kernel
+    threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
+    batch_size = in.size() / m / read_width / threads_per;
+    launch_hadamard(
+        temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
+  } else {
+    launch_hadamard(
+        in_contiguous,
+        out,
+        batch_size,
+        threads_per,
+        "n" + kernel_name,
+        scale_,
+        s);
+  }
+
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h
index f7e25c7c1..aca2c683e 100644
--- a/mlx/backend/metal/jit/includes.h
+++ b/mlx/backend/metal/jit/includes.h
@@ -18,6 +18,7 @@ const char* binary();
 const char* binary_two();
 const char* copy();
 const char* fft();
+const char* hadamard();
 const char* quantized();
 const char* ternary();
 const char* scan();
diff --git a/mlx/backend/metal/kernels/hadamard.h b/mlx/backend/metal/kernels/hadamard.h
new file mode 100644
index 000000000..da4050cf4
--- /dev/null
+++ b/mlx/backend/metal/kernels/hadamard.h
@@ -0,0 +1,167 @@
+// Copyright © 2024 Apple Inc.
+#include <metal_math>
+#include <metal_stdlib>
+
+#include "mlx/backend/metal/kernels/steel/defines.h"
+
+using namespace metal;
+
+// Thread local Hadamard transform for 2^R
+template <short R>
+METAL_FUNC void radix_func(thread float* x) {
+  constexpr short logR = __builtin_ctz(R);
+  short h = 1;
+  STEEL_PRAGMA_UNROLL
+  for (short s = 0; s < logR; s++) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < R / 2; i++) {
+      short k = i & (h - 1);
+      short j = ((i - k) << 1) + k;
+      float a = x[j];
+      float b = x[j + h];
+      x[j] = a + b;
+      x[j + h] = a - b;
+    }
+    h <<= 1;
+  }
+}
+
+template <typename T, int N, int max_radix, int read_width>
+[[kernel]] void hadamard_n(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const float& scale,
+    uint3 elem [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  // Compute a Hadamard transform of size N = 2^k
+  //
+  // Equivalent to:
+  //   from scipy.linalg import hadamard
+  //   y = hadamard(len(x)) @ x
+
+  constexpr short num_threads = N / max_radix;
+  constexpr short logN = __builtin_ctz(N);
+  constexpr short logR = __builtin_ctz(max_radix);
+  constexpr short num_steps = logN / logR;
+  constexpr short logFinal = logN % logR;
+  constexpr short final_radix = 1 << (logFinal);
+
+  int batch_idx = elem.x * N;
+  short i = elem.y;
+
+  threadgroup T buf[N];
+
+  // Read values from device
+  STEEL_PRAGMA_UNROLL
+  for (short j = 0; j < max_radix / read_width; j++) {
+    short index = j * read_width * num_threads + i * read_width;
+    STEEL_PRAGMA_UNROLL
+    for (short r = 0; r < read_width; r++) {
+      buf[index + r] = in[batch_idx + index + r];
+    }
+  }
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  float x[max_radix];
+  short h = 1;
+
+  STEEL_PRAGMA_UNROLL
+  for (short s = 0; s < num_steps; s++) {
+    short k = i & (h - 1);
+    short j = ((i - k) << logR) + k;
+
+    STEEL_PRAGMA_UNROLL
+    for (short r = 0; r < max_radix; r++) {
+      x[r] = buf[j + h * r];
+    }
+
+    radix_func<max_radix>(x);
+
+    STEEL_PRAGMA_UNROLL
+    for (short r = 0; r < max_radix; r++) {
+      buf[j + h * r] = x[r];
+    }
+
+    h <<= logR;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // Do the final radix
+  // e.g. max_radix = 16
+  //      N = 1024 = 16 * 16 * 4
+  if (final_radix > 1) {
+    // Each thread does multiple butterflies
+    STEEL_PRAGMA_UNROLL
+    for (int t = 0; t < max_radix / final_radix; t++) {
+      short index = i + t * num_threads;
+      short k = index & (h - 1);
+      short j = ((index - k) << logFinal) + k;
+      STEEL_PRAGMA_UNROLL
+      for (short r = 0; r < final_radix; r++) {
+        x[r] = buf[j + h * r];
+      }
+
+      radix_func<final_radix>(x);
+
+      STEEL_PRAGMA_UNROLL
+      for (short r = 0; r < final_radix; r++) {
+        buf[j + h * r] = x[r];
+      }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // Write values to device
+  STEEL_PRAGMA_UNROLL
+  for (short j = 0; j < max_radix / read_width; j++) {
+    short index = j * read_width * num_threads + i * read_width;
+    STEEL_PRAGMA_UNROLL
+    for (short r = 0; r < read_width; r++) {
+      out[batch_idx + index + r] = buf[index + r] * scale;
+    }
+  }
+}
+
+template <typename T, int N, int M, int read_width>
+[[kernel]] void hadamard_m(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const float& scale,
+    uint3 elem [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  // Compute a Hadamard transform of size M
+  // using a naive O(M^2) codelet.
+  //
+  // This kernel is the second stage in the computation
+  // of a Hadamard transform of size M*N where N = 2^k.
+
+  int index = elem.x * grid.y + elem.y;
+  short i = index % (N / read_width);
+  int batch_idx = index / (N / read_width) * M * N;
+
+  float x[read_width][M];
+  STEEL_PRAGMA_UNROLL
+  for (short c = 0; c < M; c++) {
+    STEEL_PRAGMA_UNROLL
+    for (short r = 0; r < read_width; r++) {
+      x[r][c] = in[batch_idx + c * N + i * read_width + r];
+    }
+  }
+
+  STEEL_PRAGMA_UNROLL
+  for (short r = 0; r < read_width; r++) {
+    // This function is JIT compiled for M
+    // using the Hadamard matrix strings in `metal/hadamard.cpp`
+    hadamard_radix_m(x[r]);
+  }
+
+  // Write back to device
+  STEEL_PRAGMA_UNROLL
+  for (short c = 0; c < M; c++) {
+    STEEL_PRAGMA_UNROLL
+    for (short r = 0; r < read_width; r++) {
+      out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
+    }
+  }
+}
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index 6ecb0a095..fc43079b4 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -130,17 +130,6 @@ inline void debug_set_primitive_buffer_label(
 #endif
 }
 
-bool is_power_of_2(int n) {
-  return ((n & (n - 1)) == 0) && n != 0;
-}
-
-int next_power_of_2(int n) {
-  if (is_power_of_2(n)) {
-    return n;
-  }
-  return pow(2, std::ceil(std::log2(n)));
-}
-
 std::string get_primitive_string(Primitive* primitive) {
   std::ostringstream op_t;
   primitive->print(op_t);
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index f652a556c..60c731930 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -61,6 +61,7 @@ NO_CPU(GatherMM)
 NO_CPU(GatherQMM)
 NO_CPU(Greater)
 NO_CPU(GreaterEqual)
+NO_CPU(Hadamard)
 NO_CPU(Less)
 NO_CPU(LessEqual)
 NO_CPU(Load)
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index a5dd87369..7dce75f7d 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -62,6 +62,7 @@ NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
 NO_GPU(Greater)
 NO_GPU(GreaterEqual)
+NO_GPU(Hadamard)
 NO_GPU(Less)
 NO_GPU(LessEqual)
 NO_GPU(Load)
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 5d019edf0..f9051e243 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -451,6 +451,18 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
   return flatten(a, 0, a.ndim() - 1, s);
 }
 
+array hadamard_transform(
+    const array& a,
+    float scale /* = 1.0 */,
+    StreamOrDevice s /* = {} */) {
+  auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
+  return array(
+      a.shape(),
+      dtype,
+      std::make_shared<Hadamard>(to_stream(s), scale),
+      {astype(a, dtype, s)});
+}
+
 array squeeze(
     const array& a,
     const std::vector<int>& axes,
diff --git a/mlx/ops.h b/mlx/ops.h
index 069400ba8..fb07bf1fa 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -131,6 +131,12 @@ array flatten(
 /** Flatten the array to 1D. */
 array flatten(const array& a, StreamOrDevice s = {});
 
+/** Multiply the array by the Hadamard matrix of corresponding size. */
+array hadamard_transform(
+    const array& a,
+    float scale = 1.0f,
+    StreamOrDevice s = {});
+
 /** Remove singleton dimensions at the given axes. */
 array squeeze(
     const array& a,
     const std::vector<int>& axes,
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 94136c6a5..8b2833ad4 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -3976,4 +3976,42 @@ bool View::is_equivalent(const Primitive& other) const {
   return (dtype_ == a_other.dtype_);
 }
 
+std::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  auto& s = stream();
+  if (axes[0] == inputs[0].ndim() - 1) {
+    auto a = moveaxis(inputs[0], axes[0], 0, s);
+    auto b = hadamard_transform(a, scale_, s);
+    return {{b}, {0}};
+  }
+  return {{hadamard_transform(inputs[0], scale_, s)}, axes};
+}
+
+std::vector<array> Hadamard::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return jvp(primals, cotangents, argnums);
+}
+
+std::vector<array> Hadamard::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return {hadamard_transform(tangents[0], scale_, stream())};
+}
+
+bool Hadamard::is_equivalent(const Primitive& other) const {
+  const Hadamard& h_other = static_cast<const Hadamard&>(other);
+  return scale_ == h_other.scale_;
+}
+
 } // namespace mlx::core
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 4bd3b421d..4255cd4b1 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1064,6 +1064,27 @@ class GreaterEqual : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Hadamard : public UnaryPrimitive {
+ public:
+  explicit Hadamard(Stream stream, float scale)
+      : UnaryPrimitive(stream), scale_(scale) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Hadamard)
+  DEFINE_INPUT_OUTPUT_SHAPE()
+
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  float scale_;
+
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Less : public UnaryPrimitive {
  public:
  explicit Less(Stream stream) : UnaryPrimitive(stream) {}
diff --git a/mlx/utils.h b/mlx/utils.h
index a86db4009..ee2512554 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -118,4 +118,16 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
 inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
   return os << static_cast<float>(v);
 }
+
+inline bool is_power_of_2(int n) {
+  return ((n & (n - 1)) == 0) && n != 0;
+}
+
+inline int next_power_of_2(int n) {
+  if (is_power_of_2(n)) {
+    return n;
+  }
+  return pow(2, std::ceil(std::log2(n)));
+}
+
 } // namespace mlx::core
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index fb7baa2ec..e2412b44d 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4372,6 +4372,35 @@ void init_ops(nb::module_& m) {
         a (array): Input array or scalar.
         dtype (Dtype): The data type to change to.
 
         Returns:
             array: The array with the new type.
       )pbdoc");
+  m.def(
+      "hadamard_transform",
+      &hadamard_transform,
+      nb::arg(),
+      "scale"_a = 1.0,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def hadamard_transform(a: array, scale: float = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Perform the Walsh-Hadamard transform along the final axis.
+
+        Equivalent to:
+
+        .. code-block:: python
+
+           from scipy.linalg import hadamard
+           y = hadamard(len(x)) @ x
+
+        Supports sizes ``n = m*2^k`` where ``m`` in ``(1, 12, 20, 28)``
+        and ``2^k <= 8192`` for FP32 and ``2^k <= 16384`` for FP16/BF16.
+
+        Args:
+            a (array): Input array or scalar.
+            scale (float): Scale the output by this factor.
+
+        Returns:
+            array: The transformed array.
+      )pbdoc");
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 9586245e6..4b0abebdb 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2425,6 +2425,104 @@ class TestOps(mlx_tests.MLXTestCase):
         a_out = out.view(mx.int32)
         self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
 
+    def _hadamard(self, N):
+        # Matches scipy.linalg.hadamard
+        H = np.array([[1]], dtype=np.int64)
+        for i in range(0, np.log2(N).astype(np.int64)):
+            H = np.vstack((np.hstack((H, H)), np.hstack((H, -H))))
+        return H
+
+    def test_hadamard(self):
+        h28_str = """
+        +------++----++-+--+-+--++--
+        -+-----+++-----+-+--+-+--++-
+        --+-----+++---+-+-+----+--++
+        ---+-----+++---+-+-+-+--+--+
+        ----+-----+++---+-+-+++--+--
+        -----+-----++++--+-+--++--+-
+        ------++----++-+--+-+--++--+
+        --++++-+-------++--+++-+--+-
+        ---++++-+-----+-++--+-+-+--+
+        +---+++--+----++-++--+-+-+--
+        ++---++---+----++-++--+-+-+-
+        +++---+----+----++-++--+-+-+
+        ++++--------+-+--++-++--+-+-
+        -++++--------+++--++--+--+-+
+        -+-++-++--++--+--------++++-
+        +-+-++--+--++--+--------++++
+        -+-+-++--+--++--+----+---+++
+        +-+-+-++--+--+---+---++---++
+        ++-+-+-++--+------+--+++---+
+        -++-+-+-++--+------+-++++---
+        +-++-+---++--+------+-++++--
+        -++--++-+-++-+++----++------
+        +-++--++-+-++-+++-----+-----
+        ++-++---+-+-++-+++-----+----
+        -++-++-+-+-+-+--+++-----+---
+        --++-++++-+-+----+++-----+--
+        +--++-+-++-+-+----+++-----+-
+        ++--++-+-++-+-+----++------+
+        """
+
+        def parse_h_string(h_str):
+            return np.array(
+                [[1 if s == "+" else -1 for s in row] for row in h_str.split()]
+            )
+
+        h28 = parse_h_string(h28_str)
+
+        np.random.seed(7)
+        tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15))
+        for dtype, m, k in tests:
+            # skip large m=28 cases because they're very slow in NumPy
+            if (m > 1 and k > 8) or (dtype != np.float16 and k == 14):
+                continue
+            with self.subTest(dtype=dtype, m=m, k=k):
+                n = m * 2**k
+                b = 4
+                scale = 0.34
+                x = np.random.normal(size=(b, n)).astype(dtype)
+                # contiguity check
+                x = mx.array(x)[::2]
+                y = mx.hadamard_transform(x, scale=scale)
+                mx.eval(y)
+                h = (
+                    self._hadamard(2**k)
+                    if m == 1
+                    else np.kron(h28, self._hadamard(2**k))
+                )
+                y_np = np.einsum("ij,bj->bi", h, x) * scale
+                atol = 2e-4 if dtype == np.float32 else 5e-2 * k
+                np.testing.assert_allclose(y, y_np, atol=atol)
+
+    def test_hadamard_grad_vmap(self):
+        np.random.seed(4)
+
+        for k in range(2, 8):
+            n = 2**k
+            x = np.random.normal(size=(n,))
+            h = self._hadamard(n)
+            c = np.random.normal(size=(n,))
+            x = mx.array(x).astype(mx.float32)
+            h = mx.array(h).astype(mx.float32)
+            c = mx.array(c).astype(mx.float32)
+
+            def hadamard_transform(x):
+                return h @ x
+
+            out = mx.vjp(hadamard_transform, [x], [c])
+            out_t = mx.vjp(mx.hadamard_transform, [x], [c])
+            np.testing.assert_allclose(out, out_t, atol=1e-4)
+
+            for axis in (0, 1, 2):
+                vht = mx.vmap(mx.vmap(hadamard_transform, 0, 0), axis, axis)
+                vht_t = mx.vmap(mx.vmap(mx.hadamard_transform, 0, 0), axis, axis)
+
+                xb = mx.array(np.random.normal(size=(n, n, n)))
+                out = vht(xb)
+                out_t = vht_t(xb)
+                np.testing.assert_allclose(out, out_t, atol=1e-4)
+
 
 if __name__ == "__main__":
     unittest.main()
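(Editor's note, not part of the diff.) Once this change is built, the op is exposed as `mx.hadamard_transform`. A short usage sketch, assuming an MLX build that includes this diff; the `1/sqrt(n)` value is simply the common choice of `scale` that makes the transform orthonormal (the default `scale` is `1.0`):

```python
import math

import mlx.core as mx

x = mx.random.normal(shape=(8, 1024))
scale = 1.0 / math.sqrt(x.shape[-1])

# Orthonormal Walsh-Hadamard transform along the last axis.
y = mx.hadamard_transform(x, scale=scale)

# For a power-of-two size with this scale the transform is its own inverse,
# so applying it twice recovers the input up to float32 rounding.
x_back = mx.hadamard_transform(y, scale=scale)
print(mx.allclose(x, x_back, atol=1e-4).item())
```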