mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
GPU Hadamard for large N (#1879)
This commit is contained in:
parent
9daa6b003f
commit
481349495b
@ -99,6 +99,10 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
if (n > (1 << 26)) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
@ -15,7 +13,6 @@
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
|
||||
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
|
||||
|
||||
std::string gen_hadamard_codelet(int m) {
|
||||
// Generate a O(m^2) hadamard codelet for a given M
|
||||
@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) {
|
||||
return source.str();
|
||||
}
|
||||
|
||||
void hadamard_mn_contiguous(
|
||||
const array& x,
|
||||
array& y,
|
||||
int m,
|
||||
int n1,
|
||||
int n2,
|
||||
float scale,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
int n = n1 * n2;
|
||||
int read_width_n1 = n1 == 2 ? 2 : 4;
|
||||
int read_width_n2 = n2 == 2 ? 2 : 4;
|
||||
int read_width_m = (n == 2 || m == 28) ? 2 : 4;
|
||||
int max_radix_1 = std::min(n1, 16);
|
||||
int max_radix_2 = std::min(n2, 16);
|
||||
float scale_n1 = 1.0;
|
||||
float scale_n2 = (m == 1) ? scale : 1.0;
|
||||
float scale_m = scale;
|
||||
|
||||
// n2 is a row contiguous power of 2 hadamard transform
|
||||
MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
|
||||
MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
|
||||
|
||||
// n1 is a strided power of 2 hadamard transform with stride n2
|
||||
MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
|
||||
MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
|
||||
|
||||
// m is a strided hadamard transform with stride n = n1 * n2
|
||||
MTL::Size group_dims_m(
|
||||
std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
|
||||
MTL::Size grid_dims_m(
|
||||
group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
|
||||
|
||||
// Make the kernel
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
|
||||
auto lib = d.get_library(kname, [&]() {
|
||||
std::string kernel;
|
||||
concatenate(
|
||||
kernel,
|
||||
metal::utils(),
|
||||
gen_hadamard_codelet(m),
|
||||
metal::hadamard(),
|
||||
get_template_definition(
|
||||
"n2" + kname,
|
||||
"hadamard_n",
|
||||
get_type_string(x.dtype()),
|
||||
n2,
|
||||
max_radix_2,
|
||||
read_width_n2));
|
||||
if (n1 > 1) {
|
||||
kernel += get_template_definition(
|
||||
"n1" + kname,
|
||||
"hadamard_n",
|
||||
get_type_string(x.dtype()),
|
||||
n1,
|
||||
max_radix_1,
|
||||
read_width_n1,
|
||||
n2);
|
||||
}
|
||||
if (m > 1) {
|
||||
kernel += get_template_definition(
|
||||
"m" + kname,
|
||||
"hadamard_m",
|
||||
get_type_string(x.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width_m);
|
||||
}
|
||||
return kernel;
|
||||
});
|
||||
|
||||
// Launch the strided transform for n1
|
||||
if (n1 > 1) {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel("n1" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_n1, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
|
||||
}
|
||||
|
||||
// Launch the transform for n2
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel("n2" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_n2, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
|
||||
|
||||
// Launch the strided transform for m
|
||||
if (m > 1) {
|
||||
auto kernel = d.get_kernel("m" + kname, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(y, 0);
|
||||
compute_encoder.set_output_array(y, 1);
|
||||
compute_encoder.set_bytes(scale_m, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
|
||||
std::vector<array> copies;
|
||||
// Only support the last axis for now
|
||||
int axis = in.ndim() - 1;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
// TODO(alexbarron) pass strides to kernel to relax this constraint
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
// Split the hadamard transform so that all of them work on vectors smaller
|
||||
// than 8192 elements.
|
||||
//
|
||||
// We decompose it in the following way:
|
||||
//
|
||||
// n = m * n1 * n2 = m * 2^k1 * 2^k2
|
||||
//
|
||||
// where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
|
||||
auto [n, m] = decompose_hadamard(in.shape().back());
|
||||
int n1 = 1, n2 = n;
|
||||
if (n > 8192) {
|
||||
for (n2 = 2; n2 * n2 < n; n2 *= 2) {
|
||||
}
|
||||
n1 = n / n2;
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.copy_shared_buffer(in_contiguous);
|
||||
if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
int n, m;
|
||||
std::tie(n, m) = decompose_hadamard(in.shape(axis));
|
||||
|
||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
|
||||
}
|
||||
|
||||
int max_radix = std::min(n, 16);
|
||||
// Use read_width 2 for m = 28 to avoid register spilling
|
||||
int read_width = (n == 2 || m == 28) ? 2 : 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "hadamard_" << n * m << "_" << type_to_name(out);
|
||||
auto kernel_name = kname.str();
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
auto codelet = gen_hadamard_codelet(m);
|
||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||
kernel_source << get_template_definition(
|
||||
"n" + kernel_name,
|
||||
"hadamard_n",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
max_radix,
|
||||
read_width);
|
||||
kernel_source << get_template_definition(
|
||||
"m" + kernel_name,
|
||||
"hadamard_m",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width);
|
||||
return kernel_source.str();
|
||||
});
|
||||
|
||||
int batch_size = in.size() / n;
|
||||
int threads_per = n / max_radix;
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
auto launch_hadamard = [&](const array& in,
|
||||
array& out,
|
||||
const std::string& kernel_name,
|
||||
float scale) {
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(scale, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
};
|
||||
|
||||
if (m > 1) {
|
||||
// When m is greater than 1, we decompose the
|
||||
// computation into two uploads to the GPU:
|
||||
//
|
||||
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
|
||||
//
|
||||
// y = h48 @ x
|
||||
//
|
||||
// Upload 1:
|
||||
// tmp = a.reshape(12, 4) @ h4
|
||||
//
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
|
||||
|
||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||
batch_size = in.size() / m / read_width / threads_per;
|
||||
launch_hadamard(temp, out, "m" + kernel_name, scale_);
|
||||
hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
|
||||
} else {
|
||||
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N, int max_radix, int read_width>
|
||||
template <typename T, int N, int max_radix, int read_width, int stride = 1>
|
||||
[[kernel]] void hadamard_n(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@ -46,12 +46,13 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
constexpr short logFinal = logN % logR;
|
||||
constexpr short final_radix = 1 << (logFinal);
|
||||
|
||||
int batch_idx = elem.x * N;
|
||||
short i = elem.y;
|
||||
int batch_idx = elem.y * N * stride + elem.z;
|
||||
short i = elem.x;
|
||||
|
||||
threadgroup T buf[N];
|
||||
|
||||
// Read values from device
|
||||
if (stride == 1) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
@ -60,6 +61,12 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
buf[index + r] = in[batch_idx + index + r];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix; j++) {
|
||||
buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@ -113,6 +120,7 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
}
|
||||
|
||||
// Write values to device
|
||||
if (stride == 1) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix / read_width; j++) {
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
@ -121,6 +129,13 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
out[batch_idx + index + r] = T(buf[index + r] * scale);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < max_radix; j++) {
|
||||
out[batch_idx + (j * num_threads + i) * stride] =
|
||||
buf[j * num_threads + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N, int M, int read_width>
|
||||
|
13
mlx/ops.cpp
13
mlx/ops.cpp
@ -473,8 +473,19 @@ array hadamard_transform(
|
||||
std::optional<float> scale_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
|
||||
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
|
||||
int n = a.ndim() > 0 ? a.shape(-1) : 1;
|
||||
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);
|
||||
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
|
||||
|
||||
// Nothing to do for a scalar
|
||||
if (n == 1) {
|
||||
if (scale == 1) {
|
||||
return a;
|
||||
}
|
||||
|
||||
return multiply(a, array(scale, dtype), s);
|
||||
}
|
||||
|
||||
return array(
|
||||
a.shape(),
|
||||
dtype,
|
||||
|
@ -2868,11 +2868,33 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
h28 = parse_h_string(h28_str)
|
||||
|
||||
x = mx.array(5)
|
||||
y = mx.hadamard_transform(x)
|
||||
self.assertEqual(y.item(), 5)
|
||||
|
||||
x = mx.array(5)
|
||||
y = mx.hadamard_transform(x, scale=0.2)
|
||||
self.assertEqual(y.item(), 1)
|
||||
|
||||
x = mx.random.normal((8, 8, 1))
|
||||
y = mx.hadamard_transform(x)
|
||||
self.assertTrue(mx.all(y == x).item())
|
||||
|
||||
# Too slow to compare to numpy so let's compare CPU to GPU
|
||||
if mx.default_device() == mx.gpu:
|
||||
rk = mx.random.key(42)
|
||||
for k in range(14, 17):
|
||||
for m in [1, 3, 5, 7]:
|
||||
x = mx.random.normal((4, m * 2**k), key=rk)
|
||||
y1 = mx.hadamard_transform(x, stream=mx.cpu)
|
||||
y2 = mx.hadamard_transform(x, stream=mx.gpu)
|
||||
self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6)
|
||||
|
||||
np.random.seed(7)
|
||||
tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15))
|
||||
tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14))
|
||||
for dtype, m, k in tests:
|
||||
# skip large m=28 cases because they're very slow in NumPy
|
||||
if (m > 1 and k > 8) or (dtype != np.float16 and k == 14):
|
||||
if m > 1 and k > 8:
|
||||
continue
|
||||
with self.subTest(dtype=dtype, m=m, k=k):
|
||||
n = m * 2**k
|
||||
|
Loading…
Reference in New Issue
Block a user