mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2 * skip GPU test on linux * fix contiguity bug * address comments * Update mlx/backend/metal/fft.cpp * Update mlx/backend/metal/fft.cpp * fix bug in synch --------- Co-authored-by: Alex Barron <abarron22@apple.com> Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,12 +1,106 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& in = inputs[0];
|
||||
throw std::runtime_error("[FFT] NYI for Metal backend.");
|
||||
|
||||
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
|
||||
in.dtype() != complex64 || out.dtype() != complex64) {
|
||||
// Could also fallback to CPU implementation here.
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
|
||||
}
|
||||
|
||||
size_t n = in.shape(axes_[0]);
|
||||
|
||||
if (!is_power_of_2(n) || n > 2048 || n < 4) {
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
|
||||
}
|
||||
|
||||
// Make sure that the array is contiguous and has stride 1 in the FFT dim
|
||||
std::vector<array> copies;
|
||||
auto check_input = [this, &copies, &s](const array& x) {
|
||||
// TODO: Pass the strides to the kernel so
|
||||
// we can avoid the copy when x is not contiguous.
|
||||
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
|
||||
x.flags().col_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
std::vector<size_t> strides;
|
||||
size_t cur_stride = x.shape(axes_[0]);
|
||||
for (int axis = 0; axis < x.ndim(); axis++) {
|
||||
if (axis == axes_[0]) {
|
||||
strides.push_back(1);
|
||||
} else {
|
||||
strides.push_back(cur_stride);
|
||||
cur_stride *= x.shape(axis);
|
||||
}
|
||||
}
|
||||
|
||||
auto flags = x.flags();
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
|
||||
f_stride *= x.shape(i);
|
||||
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
|
||||
b_stride *= x.shape(ri);
|
||||
}
|
||||
// This is probably over-conservative
|
||||
flags.contiguous = false;
|
||||
|
||||
x_copy.set_data(
|
||||
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
|
||||
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(inputs[0]);
|
||||
|
||||
// TODO: allow donation here
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
in_contiguous.data_size(),
|
||||
in_contiguous.strides(),
|
||||
in_contiguous.flags());
|
||||
|
||||
// We use n / 4 threads by default since radix-4
|
||||
// is the largest single threaded radix butterfly
|
||||
// we currently implement.
|
||||
size_t m = n / 4;
|
||||
size_t batch = in.size() / in.shape(axes_[0]);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "fft_" << n;
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
||||
bool donated = in.data_shared_ptr() == nullptr;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
auto group_dims = MTL::Size(1, m, 1);
|
||||
auto grid_dims = MTL::Size(batch, m, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -21,6 +21,7 @@ set(
|
||||
"binary_two"
|
||||
"conv"
|
||||
"copy"
|
||||
"fft"
|
||||
"gemv"
|
||||
"quantized"
|
||||
"random"
|
||||
|
195
mlx/backend/metal/kernels/fft.metal
Normal file
195
mlx/backend/metal/kernels/fft.metal
Normal file
@@ -0,0 +1,195 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Metal FFT using Stockham's algorithm
|
||||
//
|
||||
// References:
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
|
||||
#include <metal_math>
|
||||
#include <metal_common>
|
||||
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
float2 complex_mul(float2 a, float2 b) {
|
||||
float2 c;
|
||||
c.x = a.x * b.x - a.y * b.y;
|
||||
c.y = a.x * b.y + a.y * b.x;
|
||||
return c;
|
||||
}
|
||||
|
||||
float2 get_twiddle(int k, int p) {
|
||||
float theta = -1.0f * k * M_PI_F / (2*p);
|
||||
|
||||
float2 twiddle;
|
||||
twiddle.x = metal::fast::cos(theta);
|
||||
twiddle.y = metal::fast::sin(theta);
|
||||
return twiddle;
|
||||
}
|
||||
|
||||
// single threaded radix2 implemetation
|
||||
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
|
||||
float2 twiddle = get_twiddle(k, p);
|
||||
|
||||
float2 z = complex_mul(x_1, twiddle);
|
||||
|
||||
float2 y_0 = x_0 + z;
|
||||
float2 y_1 = x_0 - z;
|
||||
|
||||
int j = (i << 1) - k;
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
}
|
||||
|
||||
// single threaded radix4 implemetation
|
||||
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
float2 x_2 = read_buf[i + 2*m];
|
||||
float2 x_3 = read_buf[i + 3*m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
|
||||
float2 twiddle = get_twiddle(k, p);
|
||||
// e^a * e^b = e^(a + b)
|
||||
float2 twiddle_2 = complex_mul(twiddle, twiddle);
|
||||
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
|
||||
|
||||
x_1 = complex_mul(x_1, twiddle);
|
||||
x_2 = complex_mul(x_2, twiddle_2);
|
||||
x_3 = complex_mul(x_3, twiddle_3);
|
||||
|
||||
float2 minus_i;
|
||||
minus_i.x = 0;
|
||||
minus_i.y = -1;
|
||||
|
||||
// Hard coded twiddle factors for DFT4
|
||||
float2 z_0 = x_0 + x_2;
|
||||
float2 z_1 = x_0 - x_2;
|
||||
float2 z_2 = x_1 + x_3;
|
||||
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
|
||||
|
||||
float2 y_0 = z_0 + z_2;
|
||||
float2 y_1 = z_1 + z_3;
|
||||
float2 y_2 = z_0 - z_2;
|
||||
float2 y_3 = z_1 - z_3;
|
||||
|
||||
int j = ((i - k) << 2) + k;
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
write_buf[j + 2*p] = y_2;
|
||||
write_buf[j + 3*p] = y_3;
|
||||
}
|
||||
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-2 and radix-4 DFTs:
|
||||
// e.g. 128 = 2 * 4 * 4 * 4
|
||||
//
|
||||
// At each step we use n / 4 threads, each performing
|
||||
// a single-threaded radix-4 or radix-2 DFT.
|
||||
//
|
||||
// We provide the number of radix-2 and radix-4
|
||||
// steps at compile time for a ~20% performance boost.
|
||||
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
||||
[[kernel]] void fft(
|
||||
const device float2 *in [[buffer(0)]],
|
||||
device float2 * out [[buffer(1)]],
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]],
|
||||
uint3 threads_per_grid [[threads_per_grid]]) {
|
||||
|
||||
// Index of the DFT in batch
|
||||
int batch_idx = thread_position_in_grid.x * n;
|
||||
// The index in the DFT we're working on
|
||||
int i = thread_position_in_grid.y;
|
||||
// The number of the threads we're using for each DFT
|
||||
int m = threads_per_grid.y;
|
||||
|
||||
// Allocate 2 shared memory buffers for Stockham.
|
||||
// We alternate reading from one and writing to the other at each radix step.
|
||||
threadgroup float2 shared_in[n];
|
||||
threadgroup float2 shared_out[n];
|
||||
|
||||
// Pointers to facilitate Stockham buffer swapping
|
||||
threadgroup float2* read_buf = shared_in;
|
||||
threadgroup float2* write_buf = shared_out;
|
||||
threadgroup float2* tmp;
|
||||
|
||||
// Copy input into shared memory
|
||||
shared_in[i] = in[batch_idx + i];
|
||||
shared_in[i + m] = in[batch_idx + i + m];
|
||||
shared_in[i + 2*m] = in[batch_idx + i + 2*m];
|
||||
shared_in[i + 3*m] = in[batch_idx + i + 3*m];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
|
||||
for (size_t r = 0; r < radix_2_steps; r++) {
|
||||
radix2(i, p, m*2, read_buf, write_buf);
|
||||
radix2(i + m, p, m*2, read_buf, write_buf);
|
||||
p *= 2;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Stockham switch of buffers
|
||||
tmp = write_buf;
|
||||
write_buf = read_buf;
|
||||
read_buf = tmp;
|
||||
}
|
||||
|
||||
for (size_t r = 0; r < radix_4_steps; r++) {
|
||||
radix4(i, p, m, read_buf, write_buf);
|
||||
p *= 4;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Stockham switch of buffers
|
||||
tmp = write_buf;
|
||||
write_buf = read_buf;
|
||||
read_buf = tmp;
|
||||
}
|
||||
|
||||
// Copy shared memory to output
|
||||
out[batch_idx + i] = read_buf[i];
|
||||
out[batch_idx + i + m] = read_buf[i + m];
|
||||
out[batch_idx + i + 2*m] = read_buf[i + 2*m];
|
||||
out[batch_idx + i + 3*m] = read_buf[i + 3*m];
|
||||
}
|
||||
|
||||
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
|
||||
template [[host_name("fft_" #name)]] \
|
||||
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
|
||||
const device float2* in [[buffer(0)]], \
|
||||
device float2* out [[buffer(1)]], \
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]], \
|
||||
uint3 threads_per_grid [[threads_per_grid]]);
|
||||
|
||||
|
||||
// Explicitly define kernels for each power of 2.
|
||||
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
|
||||
instantiate_fft(8, 8, 1, 1)
|
||||
instantiate_fft(16, 16, 0, 2)
|
||||
instantiate_fft(32, 32, 1, 2)
|
||||
instantiate_fft(64, 64, 0, 3)
|
||||
instantiate_fft(128, 128, 1, 3)
|
||||
instantiate_fft(256, 256, 0, 4)
|
||||
instantiate_fft(512, 512, 1, 4)
|
||||
instantiate_fft(1024, 1024, 0, 5)
|
||||
// 2048 is the max that will fit into 32KB of threadgroup memory.
|
||||
// TODO: implement 4 step FFT for larger n.
|
||||
instantiate_fft(2048, 2048, 1, 5)
|
@@ -130,6 +130,10 @@ inline void debug_set_primitive_buffer_label(
|
||||
#endif
|
||||
}
|
||||
|
||||
bool is_power_of_2(int n) {
|
||||
return ((n & (n - 1)) == 0) && n != 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user