From 83762691bac00601cd12f4556c0b11511c8d8673 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 8 May 2025 14:14:59 -0700 Subject: [PATCH] Fix four step fft --- mlx/backend/metal/fft.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 123917e6e..0e4b27301 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -1117,8 +1117,11 @@ void fft_four_step_inplace( } // Get the kernel - auto in_type = in.dtype() == float32 ? "float" : "float2"; - auto out_type = out.dtype() == float32 ? "float" : "float2"; + auto to_type = [](const array& x) { + return x.dtype() == float32 ? "float" : "float2"; + }; + auto in_type = step == 0 ? to_type(in) : to_type(intermediate); + auto out_type = step == 0 ? to_type(intermediate) : to_type(out); std::string hash_name; std::string kname; kname.reserve(64);