mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix fft bug (#2062)
This commit is contained in:
parent
ddaa4b7dcb
commit
ef7ece9851
@ -356,20 +356,14 @@ void multi_upload_bluestein_fft(
|
|||||||
bool inverse,
|
bool inverse,
|
||||||
bool real,
|
bool real,
|
||||||
FFTPlan& plan,
|
FFTPlan& plan,
|
||||||
std::vector<array> copies,
|
std::vector<array>& copies,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
|
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
|
||||||
// algorithm
|
// algorithm
|
||||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||||
|
copies.push_back(w_k);
|
||||||
// Broadcast w_q and w_k to the batch size
|
copies.push_back(w_q);
|
||||||
Strides b_strides(in.ndim(), 0);
|
|
||||||
b_strides[axis] = 1;
|
|
||||||
array w_k_broadcast({}, complex64, nullptr, {});
|
|
||||||
array w_q_broadcast({}, complex64, nullptr, {});
|
|
||||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
|
||||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
|
||||||
|
|
||||||
auto temp_shape = inverse ? out.shape() : in.shape();
|
auto temp_shape = inverse ? out.shape() : in.shape();
|
||||||
array temp(temp_shape, complex64, nullptr, {});
|
array temp(temp_shape, complex64, nullptr, {});
|
||||||
@ -378,13 +372,13 @@ void multi_upload_bluestein_fft(
|
|||||||
if (real && !inverse) {
|
if (real && !inverse) {
|
||||||
// Convert float32->complex64
|
// Convert float32->complex64
|
||||||
copy_gpu(in, temp, CopyType::General, s);
|
copy_gpu(in, temp, CopyType::General, s);
|
||||||
|
copies.push_back(temp);
|
||||||
} else if (real && inverse) {
|
} else if (real && inverse) {
|
||||||
int back_offset = n % 2 == 0 ? 2 : 1;
|
int back_offset = n % 2 == 0 ? 2 : 1;
|
||||||
auto slice_shape = in.shape();
|
auto slice_shape = in.shape();
|
||||||
slice_shape[axis] -= back_offset;
|
slice_shape[axis] -= back_offset;
|
||||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||||
copies.push_back(slice_temp);
|
|
||||||
copies.push_back(conj_temp);
|
copies.push_back(conj_temp);
|
||||||
|
|
||||||
Shape rstarts(in.ndim(), 0);
|
Shape rstarts(in.ndim(), 0);
|
||||||
@ -394,19 +388,28 @@ void multi_upload_bluestein_fft(
|
|||||||
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
||||||
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
||||||
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
||||||
|
copies.push_back(temp);
|
||||||
} else if (inverse) {
|
} else if (inverse) {
|
||||||
unary_op_gpu({in}, temp, "Conjugate", s);
|
unary_op_gpu({in}, temp, "Conjugate", s);
|
||||||
|
copies.push_back(temp);
|
||||||
} else {
|
} else {
|
||||||
temp.copy_shared_buffer(in);
|
temp.copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Strides b_strides(in.ndim(), 0);
|
||||||
|
b_strides[axis] = 1;
|
||||||
|
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
|
||||||
|
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
|
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||||
|
|
||||||
std::vector<std::pair<int, int>> pads;
|
std::vector<std::pair<int, int>> pads;
|
||||||
auto padded_shape = out.shape();
|
auto padded_shape = out.shape();
|
||||||
padded_shape[axis] = plan.bluestein_n;
|
padded_shape[axis] = plan.bluestein_n;
|
||||||
array pad_temp(padded_shape, complex64, nullptr, {});
|
array pad_temp(padded_shape, complex64, nullptr, {});
|
||||||
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
|
auto zero = array(complex64_t{0.0f, 0.0f});
|
||||||
|
copies.push_back(zero);
|
||||||
|
pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
|
||||||
|
copies.push_back(pad_temp);
|
||||||
|
|
||||||
array pad_temp1(padded_shape, complex64, nullptr, {});
|
array pad_temp1(padded_shape, complex64, nullptr, {});
|
||||||
fft_op(
|
fft_op(
|
||||||
@ -418,7 +421,10 @@ void multi_upload_bluestein_fft(
|
|||||||
FourStepParams(),
|
FourStepParams(),
|
||||||
/*inplace=*/false,
|
/*inplace=*/false,
|
||||||
s);
|
s);
|
||||||
|
copies.push_back(pad_temp1);
|
||||||
|
|
||||||
|
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
|
||||||
|
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
|
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
|
||||||
|
|
||||||
fft_op(
|
fft_op(
|
||||||
@ -435,9 +441,11 @@ void multi_upload_bluestein_fft(
|
|||||||
Shape starts(in.ndim(), 0);
|
Shape starts(in.ndim(), 0);
|
||||||
Shape strides(in.ndim(), 1);
|
Shape strides(in.ndim(), 1);
|
||||||
starts[axis] = plan.bluestein_n - offset - n;
|
starts[axis] = plan.bluestein_n - offset - n;
|
||||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
|
||||||
|
|
||||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
|
array temp2(temp_shape, complex64, nullptr, {});
|
||||||
|
slice_gpu(pad_temp1, temp2, starts, strides, s);
|
||||||
|
|
||||||
|
binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
|
||||||
|
|
||||||
if (real && !inverse) {
|
if (real && !inverse) {
|
||||||
Shape rstarts(in.ndim(), 0);
|
Shape rstarts(in.ndim(), 0);
|
||||||
@ -449,26 +457,21 @@ void multi_upload_bluestein_fft(
|
|||||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||||
copies.push_back(temp_float);
|
copies.push_back(temp_float);
|
||||||
copies.push_back(inv_n);
|
copies.push_back(inv_n);
|
||||||
|
copies.push_back(temp1);
|
||||||
|
|
||||||
copy_gpu(temp1, temp_float, CopyType::General, s);
|
copy_gpu(temp1, temp_float, CopyType::General, s);
|
||||||
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
|
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
|
||||||
} else if (inverse) {
|
} else if (inverse) {
|
||||||
auto inv_n = array({1.0f / n}, {1}, complex64);
|
auto inv_n = array({1.0f / n}, {1}, complex64);
|
||||||
unary_op_gpu({temp1}, temp, "Conjugate", s);
|
array temp3(temp_shape, complex64, nullptr, {});
|
||||||
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
|
unary_op_gpu({temp1}, temp3, "Conjugate", s);
|
||||||
|
binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
|
||||||
copies.push_back(inv_n);
|
copies.push_back(inv_n);
|
||||||
|
copies.push_back(temp1);
|
||||||
|
copies.push_back(temp3);
|
||||||
} else {
|
} else {
|
||||||
out.copy_shared_buffer(temp1);
|
out.copy_shared_buffer(temp1);
|
||||||
}
|
}
|
||||||
|
|
||||||
copies.push_back(w_k);
|
|
||||||
copies.push_back(w_q);
|
|
||||||
copies.push_back(w_k_broadcast);
|
|
||||||
copies.push_back(w_q_broadcast);
|
|
||||||
copies.push_back(temp);
|
|
||||||
copies.push_back(temp1);
|
|
||||||
copies.push_back(pad_temp);
|
|
||||||
copies.push_back(pad_temp1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void four_step_fft(
|
void four_step_fft(
|
||||||
@ -478,8 +481,9 @@ void four_step_fft(
|
|||||||
bool inverse,
|
bool inverse,
|
||||||
bool real,
|
bool real,
|
||||||
FFTPlan& plan,
|
FFTPlan& plan,
|
||||||
std::vector<array> copies,
|
std::vector<array>& copies,
|
||||||
const Stream& s) {
|
const Stream& s,
|
||||||
|
bool in_place) {
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
if (plan.bluestein_n == -1) {
|
if (plan.bluestein_n == -1) {
|
||||||
@ -492,7 +496,14 @@ void four_step_fft(
|
|||||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||||
four_step_params.first_step = false;
|
four_step_params.first_step = false;
|
||||||
fft_op(
|
fft_op(
|
||||||
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
temp,
|
||||||
|
out,
|
||||||
|
axis,
|
||||||
|
inverse,
|
||||||
|
real,
|
||||||
|
four_step_params,
|
||||||
|
/*inplace=*/in_place,
|
||||||
|
s);
|
||||||
copies.push_back(temp);
|
copies.push_back(temp);
|
||||||
} else {
|
} else {
|
||||||
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
|
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||||
@ -574,7 +585,7 @@ void fft_op(
|
|||||||
|
|
||||||
auto plan = plan_fft(n);
|
auto plan = plan_fft(n);
|
||||||
if (plan.four_step) {
|
if (plan.four_step) {
|
||||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -483,4 +483,4 @@ template <
|
|||||||
perform_fft(fft_idx, &p, m, n, buf);
|
perform_fft(fft_idx, &p, m, n, buf);
|
||||||
|
|
||||||
read_writer.write_strided(stride, overall_n);
|
read_writer.write_strided(stride, overall_n);
|
||||||
}
|
}
|
||||||
|
@ -182,6 +182,18 @@ class TestFFT(mlx_tests.MLXTestCase):
|
|||||||
out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))
|
out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))
|
||||||
np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)
|
np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
def test_fft_into_ifft(self):
|
||||||
|
n_fft = 8193
|
||||||
|
mx.random.seed(0)
|
||||||
|
|
||||||
|
segment = mx.random.normal(shape=[1, n_fft]) + 1j * mx.random.normal(
|
||||||
|
shape=(1, n_fft)
|
||||||
|
)
|
||||||
|
segment = mx.fft.fft(segment, n=n_fft)
|
||||||
|
r = mx.fft.ifft(segment, n=n_fft)
|
||||||
|
r_np = np.fft.ifft(segment, n=n_fft)
|
||||||
|
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user