Fast Hadamard Transform (#1249)

* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
This commit is contained in:
Alex Barron
2024-07-09 20:39:01 -07:00
committed by GitHub
parent 03cf033f82
commit a3c287354f
22 changed files with 878 additions and 11 deletions

View File

@@ -52,6 +52,7 @@ make_jit_source(
)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(hadamard)
if (MLX_METAL_JIT)
target_sources(
@@ -132,6 +133,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp

View File

@@ -14,6 +14,7 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {

View File

@@ -0,0 +1,203 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate a O(m^2) hadamard codelet for a given M
// using the hadamard matrices above
//
// e.g. m = 2
// METAL_FUNC void hadamard_m(thread float *x) {
// float tmp[2];
// tmp[0] = + x[0] + x[1];
// tmp[1] = + x[0] - x[1];
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
// }
//
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
std::ostringstream source;
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
if (m == 1) {
source << "}" << std::endl;
return source.str();
}
source << " float tmp[" << m << "];" << std::endl;
auto start = 1;
auto end = matrix.find('\n', start);
int index = 0;
while (end != std::string_view::npos) {
source << " tmp[" << index << "] = ";
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
source << " " << row[i] << " x[" << i << "]";
}
source << ";" << std::endl;
start = end + 1;
end = matrix.find('\n', start);
index++;
}
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
<< std::endl;
source << "}" << std::endl;
return source.str();
}
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.move_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto [n, m] = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
lib = d.get_library(lib_name, kernel_source.str());
}
int batch_size = in.size() / n;
int threads_per = n / max_radix;
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(
in_contiguous,
temp,
batch_size,
threads_per,
"n" + kernel_name,
1.0,
s);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
} else {
launch_hadamard(
in_contiguous,
out,
batch_size,
threads_per,
"n" + kernel_name,
scale_,
s);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

View File

@@ -18,6 +18,7 @@ const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* hadamard();
const char* quantized();
const char* ternary();
const char* scan();

View File

@@ -0,0 +1,167 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include <metal_compute>
#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
// Thread local Hadamard transform for 2^R
template <short R>
METAL_FUNC void radix_func(thread float* x) {
constexpr short logR = __builtin_ctz(R);
short h = 1;
STEEL_PRAGMA_UNROLL
for (short s = 0; s < logR; s++) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < R / 2; i++) {
short k = i & (h - 1);
short j = ((i - k) << 1) + k;
float a = x[j];
float b = x[j + h];
x[j] = a + b;
x[j + h] = a - b;
}
h <<= 1;
}
}
template <typename T, int N, int max_radix, int read_width>
[[kernel]] void hadamard_n(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute a Hadamard transform of size N = 2^k
//
// Equivalent to:
// from scipy.linalg import hadamard
// y = hadamard(len(x)) @ x
constexpr short num_threads = N / max_radix;
constexpr short logN = __builtin_ctz(N);
constexpr short logR = __builtin_ctz(max_radix);
constexpr short num_steps = logN / logR;
constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.x * N;
short i = elem.y;
threadgroup T buf[N];
// Read values from device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float x[max_radix];
short h = 1;
STEEL_PRAGMA_UNROLL
for (short s = 0; s < num_steps; s++) {
short k = i & (h - 1);
short j = ((i - k) << logR) + k;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < max_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<max_radix>(x);
STEEL_PRAGMA_UNROLL
for (short r = 0; r < max_radix; r++) {
buf[j + h * r] = x[r];
}
h <<= logR;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Do the final radix
// e.g. max_radix = 16
// N = 1024 = 16 * 16 * 4
if (final_radix > 1) {
// Each thread does multiple butterflies
STEEL_PRAGMA_UNROLL
for (int t = 0; t < max_radix / final_radix; t++) {
short index = i + t * num_threads;
short k = index & (h - 1);
short j = ((index - k) << logFinal) + k;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < final_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<final_radix>(x);
STEEL_PRAGMA_UNROLL
for (short r = 0; r < final_radix; r++) {
buf[j + h * r] = x[r];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Write values to device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = buf[index + r] * scale;
}
}
}
template <typename T, int N, int M, int read_width>
[[kernel]] void hadamard_m(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute a Hadamard transform of size M
// using a naive O(M^2) codelet.
//
// This kernel is the second stage in the computation
// of a Hadamard transform of size M*N where N = 2^k.
int index = elem.x * grid.y + elem.y;
short i = index % (N / read_width);
int batch_idx = index / (N / read_width) * M * N;
float x[read_width][M];
STEEL_PRAGMA_UNROLL
for (short c = 0; c < M; c++) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
x[r][c] = in[batch_idx + c * N + i * read_width + r];
}
}
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
// This function is JIT compiled for M
// using the Hadamard matrix strings in `metal/hadamard.cpp`
hadamard_radix_m(x[r]);
}
// Write back to device
STEEL_PRAGMA_UNROLL
for (short c = 0; c < M; c++) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
}
}
}

View File

@@ -130,17 +130,6 @@ inline void debug_set_primitive_buffer_label(
#endif
}
bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}
int next_power_of_2(int n) {
if (is_power_of_2(n)) {
return n;
}
return pow(2, std::ceil(std::log2(n)));
}
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);