Add single kernel bluestein

This commit is contained in:
Angelos Katharopoulos 2025-05-08 13:15:20 -07:00
parent 6593281d25
commit 2a41caa00e

View File

@ -122,6 +122,7 @@ class FFTPlan {
STOCKHAM,
RADER,
BLUESTEIN,
MULTIUPLOAD_BLUESTEIN,
SMALL_FOUR_STEP,
LARGE_FOUR_STEP
};
@ -132,7 +133,7 @@ class FFTPlan {
type_ = NOOP;
}
// Four step fft
// Too large for Stockham so do four step fft for powers of 2
else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
if (n <= 1 << 20) {
type_ = SMALL_FOUR_STEP;
@ -145,9 +146,9 @@ class FFTPlan {
}
}
// Bluestein fft
// Too large and not power of 2 so do multi-upload Bluestein fft
else if (n > MAX_STOCKHAM_FFT_SIZE) {
type_ = BLUESTEIN;
type_ = MULTIUPLOAD_BLUESTEIN;
bluestein_n_ = next_fast_n(2 * n - 1);
}
@ -157,9 +158,15 @@ class FFTPlan {
steps_ = steps;
}
// throw for now but we have rader and bluestein to do
else {
type_ = UNSUPPORTED;
// TODO: add Rader; for now simply fall back to Bluestein when Stockham is not
// possible
else if (n > MAX_BLUESTEIN_FFT_SIZE) {
type_ = MULTIUPLOAD_BLUESTEIN;
bluestein_n_ = next_fast_n(2 * n - 1);
} else {
type_ = BLUESTEIN;
bluestein_n_ = next_fast_n(2 * n - 1);
steps_ = stockham_decompose(bluestein_n_);
}
}
@ -191,6 +198,10 @@ class FFTPlan {
return steps2_;
}
// Accessor for bluestein_n_, which the Bluestein plan types set to
// next_fast_n(2 * n - 1).
int bluestein_size() const { return bluestein_n_; }
private:
int n_;
FFTType type_;
@ -1144,6 +1155,145 @@ void fft_four_step_inplace(
}
}
// Single-dispatch Bluestein FFT along `axis`.
//
// Specializes the "bluestein_fft" kernel for the plan's Bluestein size and
// step decomposition, binds the input, the output and the two constant arrays
// returned by compute_bluestein_constants, and encodes one dispatch on the
// stream's command encoder.
//
//   plan     FFT plan; provides size(), bluestein_size() and steps().
//   in_      Input array (copied if `axis` is not already the fastest axis).
//   out      Output array; it is prepared here and bound as kernel output.
//   axis     Axis to transform.
//   inverse  Forwarded to the kernel as function constant 0.
//   real     When set, the number of threadgroups is halved.
//   d, s     Metal device and stream the work is encoded on.
void fft_bluestein(
    const FFTPlan& plan,
    const array& in_,
    array& out,
    size_t axis,
    bool inverse,
    bool real,
    metal::Device& d,
    const Stream& s) {
  // Arrange for `axis` to have stride 1. The input may be copied to achieve
  // that; the output never is, since it holds nothing useful yet.
  array in = ensure_fastest_moving_axis(in_, axis, d, s);
  prepare_output_array(in, out, axis);

  // Kernel parameters. These locals must outlive `func_consts` below, which
  // stores pointers to them.
  int n = plan.bluestein_size();
  bool power_of_2 = true;
  auto& fft_steps = plan.steps();
  int elems_per_thread = compute_elems_per_thread(n, fft_steps);
  int threads_per_fft = ceildiv(n, elems_per_thread);

  // Number of length-plan.size() transforms, counted on whichever side is
  // float32 when the output is real, otherwise on the input.
  auto& counted = out.dtype() == float32 ? out : in;
  int total_batch_size = counted.size() / plan.size();

  // Threadgroup batching and threadgroup memory size.
  int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
  int tg_mem_size = next_power_of_2(tg_batch_size * n);
  int batch_size = ceildiv(total_batch_size, tg_batch_size);
  if (real) {
    // 2 RFFTs at once
    batch_size = ceildiv(batch_size, 2);
  }

  // Function constants: 0 = inverse, 1 = power_of_2, 2 = elems_per_thread,
  // and 4 + i for the i-th step of the decomposition.
  std::vector<MTLFC> func_consts = {
      {&inverse, MTL::DataType::DataTypeBool, 0},
      {&power_of_2, MTL::DataType::DataTypeBool, 1},
      {&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
  const int num_steps = static_cast<int>(fft_steps.size());
  for (int i = 0; i < num_steps; ++i) {
    func_consts.emplace_back(&fft_steps[i], MTL::DataType::DataTypeInt, 4 + i);
  }

  // Build the kernel name (template instantiation) and the hash name
  // (kernel name plus the specialization) and fetch the kernel.
  const char* in_type = in.dtype() == float32 ? "float" : "float2";
  const char* out_type = out.dtype() == float32 ? "float" : "float2";
  std::string kname;
  std::string hash_name;
  kname.reserve(64);
  hash_name.reserve(64);
  concatenate(
      kname, "bluestein_fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
  concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
  auto template_def = get_template_definition(
      kname, "bluestein_fft", tg_mem_size, in_type, out_type);
  auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);

  // Bluestein constants for (plan.size(), plan.bluestein_size()); registered
  // as temporaries on the stream.
  auto [wk, wq] =
      compute_bluestein_constants(plan.size(), plan.bluestein_size());
  d.add_temporary(wk, s.index);
  d.add_temporary(wq, s.index);

  // Bind the arguments and dispatch.
  auto& encoder = d.get_command_encoder(s.index);
  encoder.set_compute_pipeline_state(kernel);
  encoder.set_input_array(in, 0);
  encoder.set_output_array(out, 1);
  encoder.set_input_array(wq, 2);
  encoder.set_input_array(wk, 3);
  encoder.set_bytes(plan.size(), 4);
  encoder.set_bytes(n, 5);
  encoder.set_bytes(total_batch_size, 6);

  MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
  MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
  encoder.dispatch_threads(grid_dims, group_dims);
}
// Multi-dispatch Bluestein FFT along `axis`, selected by the plan for sizes
// that get the MULTIUPLOAD_BLUESTEIN type.
//
// NOTE(review): this function is unfinished. It only prepares the complex
// input, multiplies it with w_k and zero-pads it to plan.bluestein_size();
// nothing is written to `out` yet (see the trailing "First fft" marker).
//
//   plan     FFT plan; provides size() and bluestein_size().
//   in_      Input array (float32 for a forward real transform, otherwise
//            treated as complex64).
//   out      Output array; currently only its shape is read.
//   axis     Axis to transform.
//   inverse  Selects the conjugated-input paths.
//   real     Selects the real-input (forward) or real-output (inverse) paths.
//   d, s     Metal device and stream the work is encoded on.
void fft_multi_upload_bluestein(
    const FFTPlan& plan,
    const array& in_,
    array& out,
    size_t axis,
    bool inverse,
    bool real,
    metal::Device& d,
    const Stream& s) {
  const int iaxis = static_cast<int>(axis);

  // Get Bluestein's constants using the CPU (this is done in the submission
  // thread which is pretty bad).
  auto [w_k, w_q] =
      compute_bluestein_constants(plan.size(), plan.bluestein_size());
  d.add_temporary(w_k, s.index);
  d.add_temporary(w_q, s.index);

  // Prepare the input as a complex64 array. For an inverse transform it takes
  // the shape of the output, otherwise the shape of the input.
  auto in_shape = inverse ? out.shape() : in_.shape();
  array in(std::move(in_shape), complex64, nullptr, {});
  if (real && !inverse) {
    // Real forward transform: copy the input into the complex64 array.
    copy_gpu(
        in_,
        in,
        in_.flags().row_contiguous ? CopyType::Vector : CopyType::General,
        s);
    d.add_temporary(in, s.index);
  } else if (real && inverse) {
    // Real-output inverse transform: build `in` as
    //   concat(conj(in_), reversed slice of in_)
    // along `axis`. Both pieces must be sized from `in_`, not from `in`:
    // `in` already has the output shape here, so sizing from it would make
    // the two pieces together longer than `in` itself along `axis`.
    int back_offset = plan.size() % 2 == 0 ? 2 : 1;
    auto slice_shape = in_.shape();
    slice_shape[axis] -= back_offset;
    array slice_temp(slice_shape, complex64, nullptr, {});
    array conj_temp(in_.shape(), complex64, nullptr, {});

    // Walk `axis` backwards starting back_offset entries from the end of in_.
    Shape rstarts(in_.ndim(), 0);
    Shape rstrides(in_.ndim(), 1);
    rstarts[axis] = in_.shape(axis) - back_offset;
    rstrides[axis] = -1;

    unary_op_gpu({in_}, conj_temp, "Conjugate", s);
    slice_gpu(in_, slice_temp, rstarts, rstrides, s);
    concatenate_gpu({conj_temp, slice_temp}, in, iaxis, s);

    // Every array produced above is read by encoded GPU work, so register all
    // of them as temporaries, as the sibling branches do for `in`.
    d.add_temporary(conj_temp, s.index);
    d.add_temporary(slice_temp, s.index);
    d.add_temporary(in, s.index);
  } else if (inverse) {
    // Complex inverse transform: conjugate the input.
    unary_op_gpu({in_}, in, "Conjugate", s);
    d.add_temporary(in, s.index);
  } else {
    // Complex forward transform: share the caller's buffer, no copy needed.
    in.copy_shared_buffer(in_);
  }

  // Multiply the input elementwise with w_k. The strides are 0 on every axis
  // except `axis`, so w_k is broadcast along all the other axes.
  Strides b_strides(in.ndim(), 0);
  b_strides[axis] = 1;
  array w_k_broadcast(in.shape(), complex64, nullptr, {});
  w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
  array x(in.shape(), complex64, nullptr, {});
  binary_op_gpu({in, w_k_broadcast}, x, "Multiply", s);
  d.add_temporary(x, s.index);

  // Zero-pad along `axis` up to the Bluestein size.
  auto padded_shape = out.shape();
  padded_shape[axis] = plan.bluestein_size();
  array padded_x(padded_shape, complex64, nullptr, {});
  auto zero = array(complex64_t{0.0f, 0.0f});
  pad_gpu(x, zero, padded_x, {iaxis}, {0}, s);
  d.add_temporary(zero, s.index);
  d.add_temporary(padded_x, s.index);

  // First fft
  // TODO: the remaining steps are not implemented in this revision.
}
void fft_op_inplace(
const array& in,
array& out,
@ -1164,6 +1314,8 @@ void fft_op_inplace(
return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
case FFTPlan::SMALL_FOUR_STEP:
return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
case FFTPlan::BLUESTEIN:
return fft_bluestein(plan, in, out, axis, inverse, real, d, s);
case FFTPlan::UNSUPPORTED: {
std::string msg;
concatenate(msg, "FFT of size ", plan.size(), " not supported");