diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 95678279e..153c62c02 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -356,20 +356,14 @@ void multi_upload_bluestein_fft( bool inverse, bool real, FFTPlan& plan, - std::vector copies, + std::vector& copies, const Stream& s) { // TODO(alexbarron) Implement fused kernels for mutli upload bluestein's // algorithm int n = inverse ? out.shape(axis) : in.shape(axis); auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); - - // Broadcast w_q and w_k to the batch size - Strides b_strides(in.ndim(), 0); - b_strides[axis] = 1; - array w_k_broadcast({}, complex64, nullptr, {}); - array w_q_broadcast({}, complex64, nullptr, {}); - w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size()); - w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size()); + copies.push_back(w_k); + copies.push_back(w_q); auto temp_shape = inverse ? out.shape() : in.shape(); array temp(temp_shape, complex64, nullptr, {}); @@ -378,13 +372,13 @@ void multi_upload_bluestein_fft( if (real && !inverse) { // Convert float32->complex64 copy_gpu(in, temp, CopyType::General, s); + copies.push_back(temp); } else if (real && inverse) { int back_offset = n % 2 == 0 ? 2 : 1; auto slice_shape = in.shape(); slice_shape[axis] -= back_offset; array slice_temp(slice_shape, complex64, nullptr, {}); array conj_temp(in.shape(), complex64, nullptr, {}); - copies.push_back(slice_temp); copies.push_back(conj_temp); Shape rstarts(in.ndim(), 0); @@ -394,19 +388,28 @@ void multi_upload_bluestein_fft( unary_op_gpu({in}, conj_temp, "Conjugate", s); slice_gpu(in, slice_temp, rstarts, rstrides, s); concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s); + copies.push_back(temp); } else if (inverse) { unary_op_gpu({in}, temp, "Conjugate", s); + copies.push_back(temp); } else { temp.copy_shared_buffer(in); } + Strides b_strides(in.ndim(), 0); + b_strides[axis] = 1; + array w_k_broadcast(temp.shape(), complex64, nullptr, {}); + w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size()); binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s); std::vector> pads; auto padded_shape = out.shape(); padded_shape[axis] = plan.bluestein_n; array pad_temp(padded_shape, complex64, nullptr, {}); - pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s); + auto zero = array(complex64_t{0.0f, 0.0f}); + copies.push_back(zero); + pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s); + copies.push_back(pad_temp); array pad_temp1(padded_shape, complex64, nullptr, {}); fft_op( @@ -418,7 +421,10 @@ void multi_upload_bluestein_fft( FourStepParams(), /*inplace=*/false, s); + copies.push_back(pad_temp1); + array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {}); + w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size()); binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s); fft_op( @@ -435,9 +441,11 @@ void multi_upload_bluestein_fft( Shape starts(in.ndim(), 0); Shape strides(in.ndim(), 1); starts[axis] = plan.bluestein_n - offset - n; - slice_gpu(pad_temp1, temp, starts, strides, s); - binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s); + array temp2(temp_shape, complex64, nullptr, {}); + slice_gpu(pad_temp1, temp2, starts, strides, s); + + binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s); if (real && !inverse) { Shape rstarts(in.ndim(), 0); @@ -449,26 +457,21 @@ void multi_upload_bluestein_fft( array temp_float(out.shape(), out.dtype(), nullptr, {}); copies.push_back(temp_float); copies.push_back(inv_n); + copies.push_back(temp1); copy_gpu(temp1, temp_float, CopyType::General, s); binary_op_gpu({temp_float, inv_n}, out, "Multiply", s); } else if (inverse) { auto inv_n = array({1.0f / n}, {1}, complex64); - unary_op_gpu({temp1}, temp, "Conjugate", s); - binary_op_gpu({temp, inv_n}, out, "Multiply", s); + array temp3(temp_shape, complex64, nullptr, {}); + unary_op_gpu({temp1}, temp3, "Conjugate", s); + binary_op_gpu({temp3, inv_n}, out, "Multiply", s); copies.push_back(inv_n); + copies.push_back(temp1); + copies.push_back(temp3); } else { out.copy_shared_buffer(temp1); } - - copies.push_back(w_k); - copies.push_back(w_q); - copies.push_back(w_k_broadcast); - copies.push_back(w_q_broadcast); - copies.push_back(temp); - copies.push_back(temp1); - copies.push_back(pad_temp); - copies.push_back(pad_temp1); } void four_step_fft( @@ -478,8 +481,9 @@ void four_step_fft( bool inverse, bool real, FFTPlan& plan, - std::vector copies, - const Stream& s) { + std::vector& copies, + const Stream& s, + bool in_place) { auto& d = metal::device(s.device); if (plan.bluestein_n == -1) { @@ -492,7 +496,14 @@ void four_step_fft( in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s); four_step_params.first_step = false; fft_op( - temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s); + temp, + out, + axis, + inverse, + real, + four_step_params, + /*inplace=*/in_place, + s); copies.push_back(temp); } else { multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s); @@ -574,7 +585,7 @@ void fft_op( auto plan = plan_fft(n); if (plan.four_step) { - four_step_fft(in, out, axis, inverse, real, plan, copies, s); + four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace); d.add_temporaries(std::move(copies), s.index); return; } diff --git a/mlx/backend/metal/kernels/fft.h b/mlx/backend/metal/kernels/fft.h index a4869a2ac..e478a85b6 100644 --- a/mlx/backend/metal/kernels/fft.h +++ b/mlx/backend/metal/kernels/fft.h @@ -483,4 +483,4 @@ template < perform_fft(fft_idx, &p, m, n, buf); read_writer.write_strided(stride, overall_n); -} \ No newline at end of file +} diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 95f9f7a54..ec9a48f00 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -182,6 +182,18 @@ class TestFFT(mlx_tests.MLXTestCase): out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1)))) np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5) + def test_fft_into_ifft(self): + n_fft = 8193 + mx.random.seed(0) + + segment = mx.random.normal(shape=[1, n_fft]) + 1j * mx.random.normal( + shape=(1, n_fft) + ) + segment = mx.fft.fft(segment, n=n_fft) + r = mx.fft.ifft(segment, n=n_fft) + r_np = np.fft.ifft(segment, n=n_fft) + self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5)) + if __name__ == "__main__": unittest.main()