mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Feature complete Metal FFT (#1102)
* feature complete metal fft * fix contiguity bug * jit fft * simplify rader/bluestein constant computation * remove kernel/utils.h dep * remove bf16.h dep * format --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -7,8 +7,6 @@
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test fft basics") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
array x(1.0);
|
||||
CHECK_THROWS(fft::fft(x));
|
||||
CHECK_THROWS(fft::ifft(x));
|
||||
@@ -94,13 +92,9 @@ TEST_CASE("test fft basics") {
|
||||
CHECK(array_equal(y, array(expected_1, {2, 2})).item<bool>());
|
||||
CHECK(array_equal(fft::ifft(y, 1), x).item<bool>());
|
||||
}
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test real ffts") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto x = array({1.0});
|
||||
auto y = fft::rfft(x);
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
@@ -124,14 +118,9 @@ TEST_CASE("test real ffts") {
|
||||
CHECK_EQ(y.size(), 2);
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
CHECK(array_equal(y, array({0.5f, -0.5f})).item<bool>());
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fftn") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto x = zeros({5, 5, 5});
|
||||
CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument);
|
||||
@@ -204,8 +193,6 @@ TEST_CASE("test fftn") {
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
}
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fft with provided shape") {
|
||||
@@ -234,9 +221,6 @@ TEST_CASE("test fft with provided shape") {
|
||||
}
|
||||
|
||||
TEST_CASE("test fft vmap") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto fft_fn = [](array x) { return fft::fft(x); };
|
||||
auto x = reshape(arange(8), {2, 4});
|
||||
auto y = vmap(fft_fn)(x);
|
||||
@@ -252,14 +236,9 @@ TEST_CASE("test fft vmap") {
|
||||
|
||||
y = vmap(rfft_fn, 1, 1)(x);
|
||||
CHECK(array_equal(y, fft::rfft(x, 0)).item<bool>());
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fft grads") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
// Regular
|
||||
auto fft_fn = [](array x) { return fft::fft(x); };
|
||||
auto cotangent = astype(arange(10), complex64);
|
||||
@@ -328,6 +307,4 @@ TEST_CASE("test fft grads") {
|
||||
zeros({5, 8}))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
Reference in New Issue
Block a user