Feature complete Metal FFT (#1102)

* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-06-06 12:57:25 -07:00
committed by GitHub
parent 0e585b4409
commit 27d70c7d9d
17 changed files with 2601 additions and 367 deletions

View File

@@ -7,8 +7,6 @@
using namespace mlx::core;
TEST_CASE("test fft basics") {
auto device = default_device();
set_default_device(Device::cpu);
array x(1.0);
CHECK_THROWS(fft::fft(x));
CHECK_THROWS(fft::ifft(x));
@@ -94,13 +92,9 @@ TEST_CASE("test fft basics") {
CHECK(array_equal(y, array(expected_1, {2, 2})).item<bool>());
CHECK(array_equal(fft::ifft(y, 1), x).item<bool>());
}
set_default_device(device);
}
TEST_CASE("test real ffts") {
auto device = default_device();
set_default_device(Device::cpu);
auto x = array({1.0});
auto y = fft::rfft(x);
CHECK_EQ(y.dtype(), complex64);
@@ -124,14 +118,9 @@ TEST_CASE("test real ffts") {
CHECK_EQ(y.size(), 2);
CHECK_EQ(y.dtype(), float32);
CHECK(array_equal(y, array({0.5f, -0.5f})).item<bool>());
set_default_device(device);
}
TEST_CASE("test fftn") {
auto device = default_device();
set_default_device(Device::cpu);
auto x = zeros({5, 5, 5});
CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument);
@@ -204,8 +193,6 @@ TEST_CASE("test fftn") {
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
CHECK_EQ(y.dtype(), float32);
}
set_default_device(device);
}
TEST_CASE("test fft with provided shape") {
@@ -234,9 +221,6 @@ TEST_CASE("test fft with provided shape") {
}
TEST_CASE("test fft vmap") {
auto device = default_device();
set_default_device(Device::cpu);
auto fft_fn = [](array x) { return fft::fft(x); };
auto x = reshape(arange(8), {2, 4});
auto y = vmap(fft_fn)(x);
@@ -252,14 +236,9 @@ TEST_CASE("test fft vmap") {
y = vmap(rfft_fn, 1, 1)(x);
CHECK(array_equal(y, fft::rfft(x, 0)).item<bool>());
set_default_device(device);
}
TEST_CASE("test fft grads") {
auto device = default_device();
set_default_device(Device::cpu);
// Regular
auto fft_fn = [](array x) { return fft::fft(x); };
auto cotangent = astype(arange(10), complex64);
@@ -328,6 +307,4 @@ TEST_CASE("test fft grads") {
zeros({5, 8}))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
set_default_device(device);
}