diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index 536c6f6f8..5de0217d3 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -117,6 +117,7 @@ struct OldFFTPlan {
 class FFTPlan {
  public:
   enum FFTType {
+    UNSUPPORTED,
     NOOP,
     STOCKHAM,
     RADER,
@@ -137,6 +138,8 @@ class FFTPlan {
       type_ = SMALL_FOUR_STEP;
       n2_ = n > 65536 ? 1024 : 64;
       n1_ = n / n2_;
+      steps1_ = stockham_decompose(n1_);
+      steps2_ = stockham_decompose(n2_);
     } else {
       type_ = LARGE_FOUR_STEP;
     }
@@ -156,6 +159,7 @@ class FFTPlan {
 
     // throw for now but we have rader and bluestein to do
     else {
+      type_ = UNSUPPORTED;
     }
   }
 
@@ -171,12 +175,30 @@ class FFTPlan {
     return steps_;
   }
 
+  int first_size() const {
+    return n1_;
+  }
+
+  const std::vector<int>& first_steps() const {
+    return steps1_;
+  }
+
+  int second_size() const {
+    return n2_;
+  }
+
+  const std::vector<int>& second_steps() const {
+    return steps2_;
+  }
+
  private:
   int n_;
   FFTType type_;
   std::vector<int> steps_;
   int n1_;
+  std::vector<int> steps1_;
   int n2_;
+  std::vector<int> steps2_;
   int bluestein_n_;
 };
 
@@ -997,11 +1019,11 @@ void fft_stockham_inplace(
       out.dtype() == float32 ? out.size() / n : in.size() / n;
   auto& steps = plan.steps();
   int elems_per_thread = compute_elems_per_thread(n, steps);
-  int threads_per_fft = ceildiv(plan.size(), elems_per_thread);
-  int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / plan.size(), 1);
-  int tg_mem_size = next_power_of_2(tg_batch_size * plan.size());
+  int threads_per_fft = ceildiv(n, elems_per_thread);
+  int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
+  int tg_mem_size = next_power_of_2(tg_batch_size * n);
   int batch_size = ceildiv(total_batch_size, tg_batch_size);
-  batch_size = real ? ceildiv(batch_size, 2) : batch_size;
+  batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
   std::vector<MTLFC> func_consts = {
       {&inverse, MTL::DataType::DataTypeBool, 0},
       {&power_of_2, MTL::DataType::DataTypeBool, 1},
@@ -1029,13 +1051,99 @@ void fft_stockham_inplace(
   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(out, 1);
   compute_encoder.set_bytes(n, 2);
-  compute_encoder.set_bytes(batch_size, 3);
+  compute_encoder.set_bytes(total_batch_size, 3);
 
   MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
   MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
+void fft_four_step_inplace(
+    const FFTPlan& plan,
+    const array& in_,
+    array& out,
+    size_t axis,
+    bool inverse,
+    bool real,
+    metal::Device& d,
+    const Stream& s) {
+  // Prepare the input and output arrays such that `axis` has stride 1.
+  // Possibly copy the input but never the output as it doesn't have anything
+  // useful in it yet.
+  array in = ensure_fastest_moving_axis(in_, axis, d, s);
+  prepare_output_array(in, out, axis);
+
+  // Also prepare the intermediate array for the four-step fft which is
+  // implemented with 2 kernel calls.
+  array intermediate(
+      (real && inverse) ? out.shape() : in.shape(), complex64, nullptr, {});
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+  prepare_output_array(in, intermediate, axis);
+  d.add_temporary(intermediate, s.index);
+
+  // Make the two calls
+  for (int step = 0; step < 2; step++) {
+    // Create the parameters
+    int n1 = plan.first_size();
+    int n2 = plan.second_size();
+    int n = (step == 0) ? n1 : n2;
+    bool power_of_2 = true;
+    int total_batch_size =
+        out.dtype() == float32 ? out.size() / n : in.size() / n;
+    auto& steps = (step == 0) ? plan.first_steps() : plan.second_steps();
+    int elems_per_thread = compute_elems_per_thread(n, steps);
+    int threads_per_fft = ceildiv(n, elems_per_thread);
+    int tg_batch_size =
+        std::max(MIN_THREADGROUP_MEM_SIZE / n, MIN_COALESCE_WIDTH);
+    int tg_mem_size = next_power_of_2(tg_batch_size * n);
+    int batch_size = ceildiv(total_batch_size, tg_batch_size);
+    std::vector<MTLFC> func_consts = {
+        {&inverse, MTL::DataType::DataTypeBool, 0},
+        {&power_of_2, MTL::DataType::DataTypeBool, 1},
+        {&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
+    for (int i = 0; i < steps.size(); i++) {
+      func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
+    }
+
+    // Get the kernel
+    auto in_type = in.dtype() == float32 ? "float" : "float2";
+    auto out_type = out.dtype() == float32 ? "float" : "float2";
+    std::string hash_name;
+    std::string kname;
+    kname.reserve(64);
+    hash_name.reserve(64);
+    concatenate(
+        kname,
+        "four_step_mem_",
+        tg_mem_size,
+        "_",
+        in_type,
+        "_",
+        out_type,
+        "_",
+        step,
+        (real ? "_true" : "_false"));
+    concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
+    auto template_def = get_template_definition(
+        kname, "four_step_fft", tg_mem_size, in_type, out_type, step, real);
+    auto kernel =
+        get_fft_kernel(d, kname, hash_name, func_consts, template_def);
+
+    // Launch it
+    auto& compute_encoder = d.get_command_encoder(s.index);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array((step == 0) ? in : intermediate, 0);
+    compute_encoder.set_output_array((step == 0) ? intermediate : out, 1);
+    compute_encoder.set_bytes(n1, 2);
+    compute_encoder.set_bytes(n2, 3);
+    compute_encoder.set_bytes(total_batch_size, 4);
+
+    MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
+    MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+  }
+}
+
 void fft_op_inplace(
     const array& in,
     array& out,
@@ -1053,10 +1161,17 @@ void fft_op_inplace(
       std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
       break;
     case FFTPlan::STOCKHAM:
-      fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
-      break;
+      return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
+    case FFTPlan::SMALL_FOUR_STEP:
+      return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
+    case FFTPlan::UNSUPPORTED: {
+      std::string msg;
+      concatenate(msg, "FFT of size ", plan.size(), " not supported");
+      throw std::runtime_error(msg);
+    }
     default:
       std::cout << "----- NYI ----" << std::endl;
+      break;
   }
 }
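
--
Reviewer note, not part of the patch: a minimal CPU-side sketch of the
four-step decomposition that the two `four_step_fft` kernel launches above
implement. Naive O(n^2) DFTs stand in for the Stockham passes, and every
name below is illustrative rather than mlx API: step 0 runs size-n1 FFTs
down the columns followed by a twiddle multiply into the intermediate, and
step 1 runs size-n2 FFTs along the rows with a transposed write-out.

#include <cmath>
#include <complex>
#include <vector>

using cplx = std::complex<double>;

// Naive O(n^2) DFT standing in for one Stockham sub-FFT.
std::vector<cplx> dft(const std::vector<cplx>& x) {
  const double pi = std::acos(-1.0);
  const int n = x.size();
  std::vector<cplx> y(n);
  for (int k = 0; k < n; k++) {
    cplx acc = 0;
    for (int j = 0; j < n; j++) {
      acc += x[j] * std::polar(1.0, -2.0 * pi * j * k / n);
    }
    y[k] = acc;
  }
  return y;
}

// FFT of size n = n1 * n2, viewing x as an n1 x n2 row-major matrix.
std::vector<cplx> four_step_fft(const std::vector<cplx>& x, int n1, int n2) {
  const double pi = std::acos(-1.0);
  const int n = n1 * n2;
  std::vector<cplx> tmp(n), out(n);

  // Kernel call 0: size-n1 FFT down each column, then the twiddle factor
  // exp(-2*pi*i*r*c/n) that glues the two passes together.
  for (int c = 0; c < n2; c++) {
    std::vector<cplx> col(n1);
    for (int r = 0; r < n1; r++) col[r] = x[r * n2 + c];
    col = dft(col);
    for (int r = 0; r < n1; r++) {
      tmp[r * n2 + c] = col[r] * std::polar(1.0, -2.0 * pi * r * c / n);
    }
  }

  // Kernel call 1: size-n2 FFT along each row, written out transposed so
  // that out[k1 + n1 * k2] is frequency bin k1 + n1 * k2 of the full FFT.
  for (int r = 0; r < n1; r++) {
    std::vector<cplx> row(tmp.begin() + r * n2, tmp.begin() + (r + 1) * n2);
    row = dft(row);
    for (int k2 = 0; k2 < n2; k2++) {
      out[k2 * n1 + r] = row[k2];
    }
  }
  return out;
}

Without the twiddle multiply this would be a plain 2-D FFT; with it, the
result matches a direct size-n DFT, which is why the plan only needs the
Stockham decompositions of n1 and n2 (steps1_/steps2_) rather than of n.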