From 2a41caa00e913a5d4e4364b233f4fd2478f9b546 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 8 May 2025 13:15:20 -0700 Subject: [PATCH] Add single kernel bluestein --- mlx/backend/metal/fft.cpp | 164 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 158 insertions(+), 6 deletions(-) diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 5de0217d3..123917e6e 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -122,6 +122,7 @@ class FFTPlan { STOCKHAM, RADER, BLUESTEIN, + MULTIUPLOAD_BLUESTEIN, SMALL_FOUR_STEP, LARGE_FOUR_STEP }; @@ -132,7 +133,7 @@ class FFTPlan { type_ = NOOP; } - // Four step fft + // Too large for Stockham so do four step fft for powers of 2 else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) { if (n <= 1 << 20) { type_ = SMALL_FOUR_STEP; @@ -145,9 +146,9 @@ class FFTPlan { } } - // Bluestein fft + // Too large and not power of 2 so do multi-upload Bluestein fft else if (n > MAX_STOCKHAM_FFT_SIZE) { - type_ = BLUESTEIN; + type_ = MULTIUPLOAD_BLUESTEIN; bluestein_n_ = next_fast_n(2 * n - 1); } @@ -157,9 +158,15 @@ class FFTPlan { steps_ = steps; } - // throw for now but we have rader and bluestein to do - else { - type_ = UNSUPPORTED; + // Add rader but for now simply fall back to bluestein when stockham not + // posssible + else if (n > MAX_BLUESTEIN_FFT_SIZE) { + type_ = MULTIUPLOAD_BLUESTEIN; + bluestein_n_ = next_fast_n(2 * n - 1); + } else { + type_ = BLUESTEIN; + bluestein_n_ = next_fast_n(2 * n - 1); + steps_ = stockham_decompose(bluestein_n_); } } @@ -191,6 +198,10 @@ class FFTPlan { return steps2_; } + int bluestein_size() const { + return bluestein_n_; + } + private: int n_; FFTType type_; @@ -1144,6 +1155,145 @@ void fft_four_step_inplace( } } +void fft_bluestein( + const FFTPlan& plan, + const array& in_, + array& out, + size_t axis, + bool inverse, + bool real, + metal::Device& d, + const Stream& s) { + // Prepare the input and output arrays such that `axis` has stride 1. + // Possibly copy the input but never the output as it doesn't have anything + // useful in it yet. + array in = ensure_fastest_moving_axis(in_, axis, d, s); + prepare_output_array(in, out, axis); + + // Prepare the arguments for bluestein fft + int n = plan.bluestein_size(); + bool power_of_2 = true; + int total_batch_size = out.dtype() == float32 ? out.size() / plan.size() + : in.size() / plan.size(); + auto& steps = plan.steps(); + int elems_per_thread = compute_elems_per_thread(n, steps); + int threads_per_fft = ceildiv(n, elems_per_thread); + int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1); + int tg_mem_size = next_power_of_2(tg_batch_size * n); + int batch_size = ceildiv(total_batch_size, tg_batch_size); + batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once + std::vector func_consts = { + {&inverse, MTL::DataType::DataTypeBool, 0}, + {&power_of_2, MTL::DataType::DataTypeBool, 1}, + {&elems_per_thread, MTL::DataType::DataTypeInt, 2}}; + for (int i = 0; i < steps.size(); i++) { + func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i); + } + + // Get the kernel + auto in_type = in.dtype() == float32 ? "float" : "float2"; + auto out_type = out.dtype() == float32 ? "float" : "float2"; + std::string hash_name; + std::string kname; + kname.reserve(64); + hash_name.reserve(64); + concatenate( + kname, "bluestein_fft_mem_", tg_mem_size, "_", in_type, "_", out_type); + concatenate(hash_name, kname, "_n", n, "_inv_", inverse); + auto template_def = get_template_definition( + kname, "bluestein_fft", tg_mem_size, in_type, out_type); + auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def); + + // Get the bluestein constants + auto [w_k, w_q] = + compute_bluestein_constants(plan.size(), plan.bluestein_size()); + d.add_temporary(w_k, s.index); + d.add_temporary(w_q, s.index); + + // Launch it + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_input_array(w_q, 2); + compute_encoder.set_input_array(w_k, 3); + compute_encoder.set_bytes(plan.size(), 4); + compute_encoder.set_bytes(n, 5); + compute_encoder.set_bytes(total_batch_size, 6); + + MTL::Size group_dims(1, tg_batch_size, threads_per_fft); + MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft); + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + +void fft_multi_upload_bluestein( + const FFTPlan& plan, + const array& in_, + array& out, + size_t axis, + bool inverse, + bool real, + metal::Device& d, + const Stream& s) { + // Get Bluestein's constants using the CPU (this is done in the submission + // thread which is pretty bad). + auto [w_k, w_q] = + compute_bluestein_constants(plan.size(), plan.bluestein_size()); + d.add_temporary(w_k, s.index); + d.add_temporary(w_q, s.index); + + // Prepare the input + auto in_shape = inverse ? out.shape() : in_.shape(); + array in(std::move(in_shape), complex64, nullptr, {}); + if (real && !inverse) { + copy_gpu( + in_, + in, + in_.flags().row_contiguous ? CopyType::Vector : CopyType::General, + s); + d.add_temporary(in, s.index); + } else if (real && inverse) { + int back_offset = plan.size() % 2 == 0 ? 2 : 1; + auto slice_shape = in.shape(); + slice_shape[axis] -= back_offset; + array slice_temp(slice_shape, complex64, nullptr, {}); + array conj_temp(in.shape(), complex64, nullptr, {}); + Shape rstarts(in.ndim(), 0); + Shape rstrides(in.ndim(), 1); + rstarts[axis] = in.shape(axis) - back_offset; + rstrides[axis] = -1; + unary_op_gpu({in_}, conj_temp, "Conjugate", s); + slice_gpu(in_, slice_temp, rstarts, rstrides, s); + concatenate_gpu({conj_temp, slice_temp}, in, (int)axis, s); + d.add_temporary(conj_temp, s.index); + } else if (inverse) { + unary_op_gpu({in_}, in, "Conjugate", s); + d.add_temporary(in, s.index); + } else { + in.copy_shared_buffer(in_); + } + + // Multiply with + Strides b_strides(in.ndim(), 0); + b_strides[axis] = 1; + array w_k_broadcast(in.shape(), complex64, nullptr, {}); + w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size()); + array x(in.shape(), complex64, nullptr, {}); + binary_op_gpu({in, w_k_broadcast}, x, "Multiply", s); + d.add_temporary(x, s.index); + + // Pad + auto padded_shape = out.shape(); + padded_shape[axis] = plan.bluestein_size(); + array padded_x(padded_shape, complex64, nullptr, {}); + auto zero = array(complex64_t{0.0f, 0.0f}); + pad_gpu(x, zero, padded_x, {(int)axis}, {0}, s); + d.add_temporary(zero, s.index); + d.add_temporary(padded_x, s.index); + + // First fft +} + void fft_op_inplace( const array& in, array& out, @@ -1164,6 +1314,8 @@ void fft_op_inplace( return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s); case FFTPlan::SMALL_FOUR_STEP: return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s); + case FFTPlan::BLUESTEIN: + return fft_bluestein(plan, in, out, axis, inverse, real, d, s); case FFTPlan::UNSUPPORTED: { std::string msg; concatenate(msg, "FFT of size ", plan.size(), " not supported");