mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2 * working for m*2^k * add scale and check contiguity * add size check * clean up * fix test * add grads + vmap * gpu only * skip on linux * test typo * add cpu impl * remove gpu only tests * fix linux build + add is_equivalent
This commit is contained in:
@@ -52,6 +52,7 @@ make_jit_source(
|
||||
)
|
||||
make_jit_source(scatter)
|
||||
make_jit_source(gather)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if (MLX_METAL_JIT)
|
||||
target_sources(
|
||||
@@ -132,6 +133,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
|
@@ -14,6 +14,7 @@
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
|
203
mlx/backend/metal/hadamard.cpp
Normal file
203
mlx/backend/metal/hadamard.cpp
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
|
||||
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
|
||||
|
||||
std::string gen_hadamard_codelet(int m) {
|
||||
// Generate a O(m^2) hadamard codelet for a given M
|
||||
// using the hadamard matrices above
|
||||
//
|
||||
// e.g. m = 2
|
||||
// METAL_FUNC void hadamard_m(thread float *x) {
|
||||
// float tmp[2];
|
||||
// tmp[0] = + x[0] + x[1];
|
||||
// tmp[1] = + x[0] - x[1];
|
||||
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
|
||||
// }
|
||||
//
|
||||
auto h_matrices = hadamard_matrices();
|
||||
auto& matrix = h_matrices[m];
|
||||
|
||||
std::ostringstream source;
|
||||
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
|
||||
if (m == 1) {
|
||||
source << "}" << std::endl;
|
||||
return source.str();
|
||||
}
|
||||
source << " float tmp[" << m << "];" << std::endl;
|
||||
auto start = 1;
|
||||
auto end = matrix.find('\n', start);
|
||||
|
||||
int index = 0;
|
||||
while (end != std::string_view::npos) {
|
||||
source << " tmp[" << index << "] = ";
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
source << " " << row[i] << " x[" << i << "]";
|
||||
}
|
||||
source << ";" << std::endl;
|
||||
start = end + 1;
|
||||
end = matrix.find('\n', start);
|
||||
index++;
|
||||
}
|
||||
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
|
||||
<< std::endl;
|
||||
source << "}" << std::endl;
|
||||
return source.str();
|
||||
}
|
||||
|
||||
void launch_hadamard(
|
||||
const array& in,
|
||||
array& out,
|
||||
int batch_size,
|
||||
int threads_per,
|
||||
const std::string kernel_name,
|
||||
float scale,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
const auto& lib_name = kernel_name.substr(1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&scale, sizeof(float), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
std::vector<array> copies;
|
||||
// Only support the last axis for now
|
||||
int axis = in.ndim() - 1;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
// TODO(alexbarron) pass strides to kernel to relax this constraint
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.move_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto [n, m] = decompose_hadamard(in.shape(axis));
|
||||
|
||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
|
||||
}
|
||||
|
||||
int max_radix = std::min(n, 16);
|
||||
// Use read_width 2 for m = 28 to avoid register spilling
|
||||
int read_width = (n == 2 || m == 28) ? 2 : 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "hadamard_" << n * m << "_" << type_to_name(out);
|
||||
auto kernel_name = kname.str();
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto codelet = gen_hadamard_codelet(m);
|
||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||
kernel_source << get_template_definition(
|
||||
"n" + kernel_name,
|
||||
"hadamard_n",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
max_radix,
|
||||
read_width);
|
||||
kernel_source << get_template_definition(
|
||||
"m" + kernel_name,
|
||||
"hadamard_m",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
|
||||
int batch_size = in.size() / n;
|
||||
int threads_per = n / max_radix;
|
||||
|
||||
if (m > 1) {
|
||||
// When m is greater than 1, we decompose the
|
||||
// computation into two uploads to the GPU:
|
||||
//
|
||||
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
|
||||
//
|
||||
// y = h48 @ x
|
||||
//
|
||||
// Upload 1:
|
||||
// tmp = a.reshape(12, 4) @ h4
|
||||
//
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(
|
||||
in_contiguous,
|
||||
temp,
|
||||
batch_size,
|
||||
threads_per,
|
||||
"n" + kernel_name,
|
||||
1.0,
|
||||
s);
|
||||
|
||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||
batch_size = in.size() / m / read_width / threads_per;
|
||||
launch_hadamard(
|
||||
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
|
||||
} else {
|
||||
launch_hadamard(
|
||||
in_contiguous,
|
||||
out,
|
||||
batch_size,
|
||||
threads_per,
|
||||
"n" + kernel_name,
|
||||
scale_,
|
||||
s);
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -18,6 +18,7 @@ const char* binary();
|
||||
const char* binary_two();
|
||||
const char* copy();
|
||||
const char* fft();
|
||||
const char* hadamard();
|
||||
const char* quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
|
167
mlx/backend/metal/kernels/hadamard.h
Normal file
167
mlx/backend/metal/kernels/hadamard.h
Normal file
@@ -0,0 +1,167 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <metal_common>
|
||||
#include <metal_compute>
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Thread local Hadamard transform for 2^R
|
||||
template <short R>
|
||||
METAL_FUNC void radix_func(thread float* x) {
|
||||
constexpr short logR = __builtin_ctz(R);
|
||||
short h = 1;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short s = 0; s < logR; s++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < R / 2; i++) {
|
||||
short k = i & (h - 1);
|
||||
short j = ((i - k) << 1) + k;
|
||||
float a = x[j];
|
||||
float b = x[j + h];
|
||||
x[j] = a + b;
|
||||
x[j + h] = a - b;
|
||||
}
|
||||
h <<= 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N, int max_radix, int read_width>
|
||||
[[kernel]] void hadamard_n(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const float& scale,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Compute a Hadamard transform of size N = 2^k
|
||||
//
|
||||
// Equivalent to:
|
||||
// from scipy.linalg import hadamard
|
||||
// y = hadamard(len(x)) @ x
|
||||
|
||||
constexpr short num_threads = N / max_radix;
|
||||
constexpr short logN = __builtin_ctz(N);
|
||||
constexpr short logR = __builtin_ctz(max_radix);
|
||||
constexpr short num_steps = logN / logR;
|
||||
constexpr short logFinal = logN % logR;
|
||||
constexpr short final_radix = 1 << (logFinal);
|
||||
|
||||
int batch_idx = elem.x * N;
|
||||
short i = elem.y;
|
||||
|
||||
threadgroup T buf[N];
|
||||
|
||||
// Read values from device
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
buf[index + r] = in[batch_idx + index + r];
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float x[max_radix];
|
||||
short h = 1;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short s = 0; s < num_steps; s++) {
|
||||
short k = i & (h - 1);
|
||||
short j = ((i - k) << logR) + k;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < max_radix; r++) {
|
||||
x[r] = buf[j + h * r];
|
||||
}
|
||||
|
||||
radix_func<max_radix>(x);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < max_radix; r++) {
|
||||
buf[j + h * r] = x[r];
|
||||
}
|
||||
|
||||
h <<= logR;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// Do the final radix
|
||||
// e.g. max_radix = 16
|
||||
// N = 1024 = 16 * 16 * 4
|
||||
if (final_radix > 1) {
|
||||
// Each thread does multiple butterflies
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int t = 0; t < max_radix / final_radix; t++) {
|
||||
short index = i + t * num_threads;
|
||||
short k = index & (h - 1);
|
||||
short j = ((index - k) << logFinal) + k;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < final_radix; r++) {
|
||||
x[r] = buf[j + h * r];
|
||||
}
|
||||
|
||||
radix_func<final_radix>(x);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < final_radix; r++) {
|
||||
buf[j + h * r] = x[r];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// Write values to device
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
out[batch_idx + index + r] = buf[index + r] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N, int M, int read_width>
|
||||
[[kernel]] void hadamard_m(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const float& scale,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Compute a Hadamard transform of size M
|
||||
// using a naive O(M^2) codelet.
|
||||
//
|
||||
// This kernel is the second stage in the computation
|
||||
// of a Hadamard transform of size M*N where N = 2^k.
|
||||
|
||||
int index = elem.x * grid.y + elem.y;
|
||||
short i = index % (N / read_width);
|
||||
int batch_idx = index / (N / read_width) * M * N;
|
||||
|
||||
float x[read_width][M];
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short c = 0; c < M; c++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
x[r][c] = in[batch_idx + c * N + i * read_width + r];
|
||||
}
|
||||
}
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
// This function is JIT compiled for M
|
||||
// using the Hadamard matrix strings in `metal/hadamard.cpp`
|
||||
hadamard_radix_m(x[r]);
|
||||
}
|
||||
|
||||
// Write back to device
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short c = 0; c < M; c++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
|
||||
}
|
||||
}
|
||||
}
|
@@ -130,17 +130,6 @@ inline void debug_set_primitive_buffer_label(
|
||||
#endif
|
||||
}
|
||||
|
||||
bool is_power_of_2(int n) {
|
||||
return ((n & (n - 1)) == 0) && n != 0;
|
||||
}
|
||||
|
||||
int next_power_of_2(int n) {
|
||||
if (is_power_of_2(n)) {
|
||||
return n;
|
||||
}
|
||||
return pow(2, std::ceil(std::log2(n)));
|
||||
}
|
||||
|
||||
std::string get_primitive_string(Primitive* primitive) {
|
||||
std::ostringstream op_t;
|
||||
primitive->print(op_t);
|
||||
|
Reference in New Issue
Block a user