From da98e8bce87d3a362bd215c84f0f11f397f47944 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 6 May 2025 21:46:21 -0700 Subject: [PATCH] Refactored stockham --- mlx/backend/metal/fft.cpp | 222 +++++++++++++++++++++++++++++++++++--- 1 file changed, 205 insertions(+), 17 deletions(-) diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index a34da87ff..536c6f6f8 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -51,6 +51,35 @@ std::vector prime_factors(int n) { return factors; } +int next_fast_n(int n) { + return next_power_of_2(n); +} + +std::vector stockham_decompose(int n) { + auto radices = supported_radices(); + std::vector steps(radices.size(), 0); + int orig_n = n; + + for (int i = 0; i < radices.size(); i++) { + int radix = radices[i]; + + // Manually tuned radices for powers of 2 + if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) { + continue; + } + + while (n % radix == 0) { + steps[i] += 1; + n /= radix; + if (n == 1) { + return steps; + } + } + } + + return {}; +} + struct FourStepParams { bool required = false; bool first_step = true; @@ -70,7 +99,7 @@ void fft_op( metal::Device& d, const Stream& s); -struct FFTPlan { +struct OldFFTPlan { int n = 0; // Number of steps for each radix in the Stockham decomposition std::vector stockham; @@ -85,9 +114,71 @@ struct FFTPlan { int n2 = 0; }; -int next_fast_n(int n) { - return next_power_of_2(n); -} +class FFTPlan { + public: + enum FFTType { + NOOP, + STOCKHAM, + RADER, + BLUESTEIN, + SMALL_FOUR_STEP, + LARGE_FOUR_STEP + }; + + FFTPlan(int n) : n_(n) { + // NOOP + if (n == 1) { + type_ = NOOP; + } + + // Four step fft + else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) { + if (n <= 1 << 20) { + type_ = SMALL_FOUR_STEP; + n2_ = n > 65536 ? 1024 : 64; + n1_ = n / n2_; + } else { + type_ = LARGE_FOUR_STEP; + } + } + + // Bluestein fft + else if (n > MAX_STOCKHAM_FFT_SIZE) { + type_ = BLUESTEIN; + bluestein_n_ = next_fast_n(2 * n - 1); + } + + // Stockham fft + else if (auto steps = stockham_decompose(n); steps.size() > 0) { + type_ = STOCKHAM; + steps_ = steps; + } + + // throw for now but we have rader and bluestein to do + else { + } + } + + FFTType type() const { + return type_; + } + + int size() const { + return n_; + } + + const std::vector& steps() const { + return steps_; + } + + private: + int n_; + FFTType type_; + std::vector steps_; + int n1_; + int n2_; + int bluestein_n_; +}; std::vector plan_stockham_fft(int n) { auto radices = supported_radices(); @@ -113,10 +204,10 @@ std::vector plan_stockham_fft(int n) { throw std::runtime_error("Unplannable"); } -FFTPlan plan_fft(int n) { +OldFFTPlan plan_fft(int n) { auto radices = supported_radices(); - FFTPlan plan; + OldFFTPlan plan; plan.n = n; plan.rader = std::vector(radices.size(), 0); @@ -176,7 +267,7 @@ FFTPlan plan_fft(int n) { return plan; } -int compute_elems_per_thread(FFTPlan plan) { +int compute_elems_per_thread(OldFFTPlan plan) { // Heuristics for selecting an efficient number // of threads to use for a particular mixed-radix FFT. auto n = plan.n; @@ -359,7 +450,7 @@ void multi_upload_bluestein_fft( size_t axis, bool inverse, bool real, - FFTPlan& plan, + OldFFTPlan& plan, std::vector& copies, const Stream& s) { auto& d = metal::device(s.device); @@ -488,7 +579,7 @@ void four_step_fft( size_t axis, bool inverse, bool real, - FFTPlan& plan, + OldFFTPlan& plan, std::vector& copies, const Stream& s, bool in_place) { @@ -771,6 +862,51 @@ void fft_op( d.add_temporaries(std::move(copies), s.index); } +inline int compute_elems_per_thread(int n, const std::vector& steps) { + auto radices = supported_radices(); + std::set used_radices; + for (int i = 0; i < steps.size(); i++) { + if (steps[i] > 0) { + used_radices.insert(radices[i % radices.size()]); + } + } + + // Manual tuning for 7/11/13 + if (used_radices.find(7) != used_radices.end() && + (used_radices.find(11) != used_radices.end() || + used_radices.find(13) != used_radices.end())) { + return 7; + } else if ( + used_radices.find(11) != used_radices.end() && + used_radices.find(13) != used_radices.end()) { + return 11; + } + + // TODO(alexbarron) Some really weird stuff is going on + // for certain `elems_per_thread` on large composite n. + // Possibly a compiler issue? + if (n == 3159) + return 13; + if (n == 3645) + return 5; + if (n == 3969) + return 7; + if (n == 1982) + return 5; + + if (used_radices.size() == 1) { + return *(used_radices.begin()); + } + if (used_radices.size() == 2 && + (used_radices.find(11) != used_radices.end() || + used_radices.find(13) != used_radices.end())) { + return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2; + } + + // In all other cases use the second smallest radix. + return *(++used_radices.begin()); +} + inline array ensure_fastest_moving_axis( const array& x, int axis, @@ -840,6 +976,7 @@ inline void prepare_output_array(const array& in, array& out, int axis) { } void fft_stockham_inplace( + const FFTPlan& plan, const array& in_, array& out, size_t axis, @@ -847,8 +984,56 @@ void fft_stockham_inplace( bool real, metal::Device& d, const Stream& s) { + // Prepare the input and output arrays such that `axis` has stride 1. + // Possibly copy the input but never the output as it doesn't have anything + // useful in it yet. array in = ensure_fastest_moving_axis(in_, axis, d, s); prepare_output_array(in, out, axis); + + // Prepare the arguments for stockham fft + int n = plan.size(); + bool power_of_2 = is_power_of_2(n); + int total_batch_size = + out.dtype() == float32 ? out.size() / n : in.size() / n; + auto& steps = plan.steps(); + int elems_per_thread = compute_elems_per_thread(n, steps); + int threads_per_fft = ceildiv(plan.size(), elems_per_thread); + int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / plan.size(), 1); + int tg_mem_size = next_power_of_2(tg_batch_size * plan.size()); + int batch_size = ceildiv(total_batch_size, tg_batch_size); + batch_size = real ? ceildiv(batch_size, 2) : batch_size; + std::vector func_consts = { + {&inverse, MTL::DataType::DataTypeBool, 0}, + {&power_of_2, MTL::DataType::DataTypeBool, 1}, + {&elems_per_thread, MTL::DataType::DataTypeInt, 2}}; + for (int i = 0; i < steps.size(); i++) { + func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i); + } + + // Get the kernel + auto in_type = in.dtype() == float32 ? "float" : "float2"; + auto out_type = out.dtype() == float32 ? "float" : "float2"; + std::string hash_name; + std::string kname; + kname.reserve(64); + hash_name.reserve(64); + concatenate(kname, "fft_mem_", tg_mem_size, "_", in_type, "_", out_type); + concatenate(hash_name, kname, "_n", n, "_inv_", inverse); + auto template_def = + get_template_definition(kname, "fft", tg_mem_size, in_type, out_type); + auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def); + + // Launch it + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(n, 2); + compute_encoder.set_bytes(batch_size, 3); + + MTL::Size group_dims(1, tg_batch_size, threads_per_fft); + MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void fft_op_inplace( @@ -860,15 +1045,18 @@ void fft_op_inplace( metal::Device& d, const Stream& s) { // Get the FFT size and plan it - size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis); - auto plan = plan_fft(n); - if (n == 1) { - std::cout << "--------------> 1-size FFT <-----------------" << std::endl; - } + auto plan = + FFTPlan(out.dtype() == float32 ? out.shape(axis) : in.shape(axis)); - if (plan.four_step && plan.bluestein_n < 0) { - // four_step_fft(in, out, axis, inverse, real, plan, inplace, d, s); - return; + switch (plan.type()) { + case FFTPlan::NOOP: + std::cout << "--------------> 1-size FFT <-----------------" << std::endl; + break; + case FFTPlan::STOCKHAM: + fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s); + break; + default: + std::cout << "----- NYI ----" << std::endl; } }