mirror of https://github.com/ml-explore/mlx.git
synced 2025-07-24 19:11:17 +08:00
Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2
* skip GPU test on linux
* fix contiguity bug
* address comments
* Update mlx/backend/metal/fft.cpp
* Update mlx/backend/metal/fft.cpp
* fix bug in synch

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent ae18326533
commit 2e7c02d5cd
57  benchmarks/python/fft_bench.py  Normal file
@ -0,0 +1,57 @@
# Copyright © 2024 Apple Inc.

import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def bandwidth_gb(runtime_ms, system_size):
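    # complex64 elements are 8 bytes; each element is read once and written
    # once per FFT, hence 2 * itemsize bytes of traffic per element.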
    bytes_per_fft = np.dtype(np.complex64).itemsize * 2
    bytes_per_gb = 1e9
    ms_per_s = 1e3
    return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb


def run_bench(system_size):
    def fft(x):
        out = mx.fft.fft(x)
        mx.eval(out)
        return out

    bandwidths = []
    for k in range(4, 12):
        n = 2**k
        x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
        x = x.astype(mx.complex64)
        mx.eval(x)
        runtime_ms = measure_runtime(fft, x=x)
        bandwidths.append(bandwidth_gb(runtime_ms, system_size))

    return bandwidths


def time_fft():
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=int(2**22))

    with mx.stream(mx.gpu):
        gpu_bandwidths = run_bench(system_size=int(2**29))

    # plot bandwidths
    x = [2**k for k in range(4, 12)]
    plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
    plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
    plt.title("MLX FFT Benchmark")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig("fft_plot.png")


if __name__ == "__main__":
    time_fft()
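time_utils.measure_runtime is the shared benchmark helper; a minimal stand-in (hypothetical, shown only to make the benchmark self-explanatory) could be:

import time

def measure_runtime(fn, iters=100, **kwargs):
    # warm up once so lazy compilation does not skew the timing
    fn(**kwargs)
    tic = time.perf_counter()
    for _ in range(iters):
        fn(**kwargs)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / iters  # milliseconds per call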
@ -1,12 +1,106 @@
// Copyright © 2023 Apple Inc.

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"

namespace mlx::core {

void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);

  auto& in = inputs[0];

  if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
      in.dtype() != complex64 || out.dtype() != complex64) {
    // Could also fall back to the CPU implementation here.
    throw std::runtime_error(
        "GPU FFT is only implemented for 1D, forward, complex FFTs.");
  }

  size_t n = in.shape(axes_[0]);

  if (!is_power_of_2(n) || n > 2048 || n < 4) {
    throw std::runtime_error(
        "GPU FFT is only implemented for powers of 2 from 4 to 2048.");
  }

  // Make sure that the array is contiguous and has stride 1 in the FFT dim
  std::vector<array> copies;
  auto check_input = [this, &copies, &s](const array& x) {
    // TODO: Pass the strides to the kernel so
    // we can avoid the copy when x is not contiguous.
    bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
        x.flags().col_contiguous;
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      std::vector<size_t> strides;
      size_t cur_stride = x.shape(axes_[0]);
      for (int axis = 0; axis < x.ndim(); axis++) {
        if (axis == axes_[0]) {
          strides.push_back(1);
        } else {
          strides.push_back(cur_stride);
          cur_stride *= x.shape(axis);
        }
      }
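      // e.g. shape (5, 6, 7) with axes_[0] == 1 yields strides (6, 1, 30):
      // the FFT axis becomes the fastest moving and the rest follow in order.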

      auto flags = x.flags();
      size_t f_stride = 1;
      size_t b_stride = 1;
      flags.col_contiguous = true;
      flags.row_contiguous = true;
      for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
        flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
        f_stride *= x.shape(i);
        flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
        b_stride *= x.shape(ri);
      }
      // This is probably over-conservative
      flags.contiguous = false;

      x_copy.set_data(
          allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
      copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
      copies.push_back(x_copy);
      return x_copy;
    }
  };
  const array& in_contiguous = check_input(inputs[0]);

  // TODO: allow donation here
  out.set_data(
      allocator::malloc_or_wait(out.nbytes()),
      in_contiguous.data_size(),
      in_contiguous.strides(),
      in_contiguous.flags());

  // We use n / 4 threads by default since radix-4
  // is the largest single threaded radix butterfly
  // we currently implement.
  size_t m = n / 4;
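  // e.g. n = 2048 uses 512 threads per FFT, each covering 4 elements.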
  size_t batch = in.size() / in.shape(axes_[0]);

  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    std::ostringstream kname;
    kname << "fft_" << n;
    auto kernel = d.get_kernel(kname.str());

    bool donated = in.data_shared_ptr() == nullptr;
    compute_encoder->setComputePipelineState(kernel);
    compute_encoder.set_input_array(in_contiguous, 0);
    compute_encoder.set_output_array(out, 1);

    auto group_dims = MTL::Size(1, m, 1);
    auto grid_dims = MTL::Size(batch, m, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

} // namespace mlx::core
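As a quick sanity check of the new GPU path (an illustrative sketch, assuming the Metal backend is available), the kernel can be exercised from Python and compared against NumPy:

import mlx.core as mx
import numpy as np

r = np.random.rand(8, 1024).astype(np.float32)
i = np.random.rand(8, 1024).astype(np.float32)
x_np = r + 1j * i
x = mx.array(x_np.astype(np.complex64))

with mx.stream(mx.gpu):
    y = mx.fft.fft(x)  # power-of-2 length on the last axis -> Metal kernel
    mx.eval(y)

np.testing.assert_allclose(np.fft.fft(x_np), y, atol=1e-4, rtol=1e-4)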
@ -21,6 +21,7 @@ set(
"binary_two"
"conv"
"copy"
"fft"
"gemv"
"quantized"
"random"
195  mlx/backend/metal/kernels/fft.metal  Normal file
@ -0,0 +1,195 @@
// Copyright © 2024 Apple Inc.

// Metal FFT using Stockham's algorithm
//
// References:
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
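//
// Stockham autosorts as it goes: every pass writes its outputs in permuted
// order while ping-ponging between two threadgroup buffers, so no separate
// bit-reversal reordering pass is needed at the end.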

#include <metal_math>
#include <metal_common>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

float2 complex_mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x - a.y * b.y;
  c.y = a.x * b.y + a.y * b.x;
  return c;
}

float2 get_twiddle(int k, int p) {
  float theta = -1.0f * k * M_PI_F / (2*p);
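  // theta = -2*pi*k / (4p), so this is the base radix-4 twiddle w_{4p}^k.
  // In the instantiations below the radix-2 step only ever runs at p == 1,
  // where k == 0 and the twiddle is 1 regardless.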

  float2 twiddle;
  twiddle.x = metal::fast::cos(theta);
  twiddle.y = metal::fast::sin(theta);
  return twiddle;
}

// single threaded radix2 implementation
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
  float2 x_0 = read_buf[i];
  float2 x_1 = read_buf[i + m];

  // The index within this sub-DFT
  int k = i & (p - 1);

  float2 twiddle = get_twiddle(k, p);

  float2 z = complex_mul(x_1, twiddle);

  float2 y_0 = x_0 + z;
  float2 y_1 = x_0 - z;

  int j = (i << 1) - k;

  write_buf[j] = y_0;
  write_buf[j + p] = y_1;
}

// single threaded radix4 implementation
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
  float2 x_0 = read_buf[i];
  float2 x_1 = read_buf[i + m];
  float2 x_2 = read_buf[i + 2*m];
  float2 x_3 = read_buf[i + 3*m];

  // The index within this sub-DFT
  int k = i & (p - 1);

  float2 twiddle = get_twiddle(k, p);
  // e^a * e^b = e^(a + b)
  float2 twiddle_2 = complex_mul(twiddle, twiddle);
  float2 twiddle_3 = complex_mul(twiddle, twiddle_2);

  x_1 = complex_mul(x_1, twiddle);
  x_2 = complex_mul(x_2, twiddle_2);
  x_3 = complex_mul(x_3, twiddle_3);

  float2 minus_i;
  minus_i.x = 0;
  minus_i.y = -1;

  // Hard coded twiddle factors for DFT4
  float2 z_0 = x_0 + x_2;
  float2 z_1 = x_0 - x_2;
  float2 z_2 = x_1 + x_3;
  float2 z_3 = complex_mul(x_1 - x_3, minus_i);

  float2 y_0 = z_0 + z_2;
  float2 y_1 = z_1 + z_3;
  float2 y_2 = z_0 - z_2;
  float2 y_3 = z_1 - z_3;

  int j = ((i - k) << 2) + k;

  write_buf[j] = y_0;
  write_buf[j + p] = y_1;
  write_buf[j + 2*p] = y_2;
  write_buf[j + 3*p] = y_3;
}


// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-2 and radix-4 DFTs:
// e.g. 128 = 2 * 4 * 4 * 4
//
// At each step we use n / 4 threads, each performing
// a single-threaded radix-4 or radix-2 DFT.
//
// We provide the number of radix-2 and radix-4
// steps at compile time for a ~20% performance boost.
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
    const device float2 *in [[buffer(0)]],
    device float2 * out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]],
    uint3 threads_per_grid [[threads_per_grid]]) {

  // Index of the DFT in the batch
  int batch_idx = thread_position_in_grid.x * n;
  // The index in the DFT we're working on
  int i = thread_position_in_grid.y;
  // The number of threads we're using for each DFT
  int m = threads_per_grid.y;

  // Allocate 2 shared memory buffers for Stockham.
  // We alternate reading from one and writing to the other at each radix step.
  threadgroup float2 shared_in[n];
  threadgroup float2 shared_out[n];

  // Pointers to facilitate Stockham buffer swapping
  threadgroup float2* read_buf = shared_in;
  threadgroup float2* write_buf = shared_out;
  threadgroup float2* tmp;

  // Copy input into shared memory
  shared_in[i] = in[batch_idx + i];
  shared_in[i + m] = in[batch_idx + i + m];
  shared_in[i + 2*m] = in[batch_idx + i + 2*m];
  shared_in[i + 3*m] = in[batch_idx + i + 3*m];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  int p = 1;

  for (size_t r = 0; r < radix_2_steps; r++) {
    radix2(i, p, m*2, read_buf, write_buf);
    radix2(i + m, p, m*2, read_buf, write_buf);
    p *= 2;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stockham switch of buffers
    tmp = write_buf;
    write_buf = read_buf;
    read_buf = tmp;
  }

  for (size_t r = 0; r < radix_4_steps; r++) {
    radix4(i, p, m, read_buf, write_buf);
    p *= 4;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stockham switch of buffers
    tmp = write_buf;
    write_buf = read_buf;
    read_buf = tmp;
  }

  // Copy shared memory to output
  out[batch_idx + i] = read_buf[i];
  out[batch_idx + i + m] = read_buf[i + m];
  out[batch_idx + i + 2*m] = read_buf[i + 2*m];
  out[batch_idx + i + 3*m] = read_buf[i + 3*m];
}

#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
  template [[host_name("fft_" #name)]] \
  [[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
      const device float2* in [[buffer(0)]], \
      device float2* out [[buffer(1)]], \
      uint3 thread_position_in_grid [[thread_position_in_grid]], \
      uint3 threads_per_grid [[threads_per_grid]]);


// Explicitly define kernels for each power of 2.
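// For each n, log2(n) == radix_2_steps + 2 * radix_4_steps.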
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
instantiate_fft(8, 8, 1, 1)
instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2)
instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3)
instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5)
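For reference, the same Stockham radix-2/radix-4 scheme can be written in NumPy; this sketch (illustrative only, not part of the commit) mirrors the kernel's twiddles, indexing, and buffer swaps so the logic can be checked on the CPU:

import numpy as np

def stockham_fft(x):
    n = len(x)
    read_buf = np.array(x, dtype=np.complex64)
    write_buf = np.empty(n, dtype=np.complex64)
    p = 1
    # one radix-2 pass when log2(n) is odd, then radix-4 passes
    for _ in range(int(np.log2(n)) % 2):
        for i in range(n // 2):
            k = i & (p - 1)
            w = np.exp(-2j * np.pi * k / (2 * p))
            z = w * read_buf[i + n // 2]
            j = (i << 1) - k
            write_buf[j] = read_buf[i] + z
            write_buf[j + p] = read_buf[i] - z
        p *= 2
        read_buf, write_buf = write_buf, read_buf
    m = n // 4
    while p < n:
        for i in range(m):
            k = i & (p - 1)
            w = np.exp(-2j * np.pi * k / (4 * p))  # same as get_twiddle(k, p)
            x0 = read_buf[i]
            x1 = w * read_buf[i + m]
            x2 = w * w * read_buf[i + 2 * m]
            x3 = w * w * w * read_buf[i + 3 * m]
            z0, z1 = x0 + x2, x0 - x2
            z2, z3 = x1 + x3, -1j * (x1 - x3)
            j = ((i - k) << 2) + k
            write_buf[j] = z0 + z2
            write_buf[j + p] = z1 + z3
            write_buf[j + 2 * p] = z0 - z2
            write_buf[j + 3 * p] = z1 - z3
        p *= 4
        read_buf, write_buf = write_buf, read_buf
    return read_buf

x = np.random.rand(64).astype(np.complex64)
np.testing.assert_allclose(stockham_fft(x), np.fft.fft(x), atol=1e-4, rtol=1e-4)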
@ -130,6 +130,10 @@ inline void debug_set_primitive_buffer_label(
#endif
}

bool is_power_of_2(int n) {
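  // n & (n - 1) clears the lowest set bit, so it is 0 iff n is 0 or a power of 2.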
  return ((n & (n - 1)) == 0) && n != 0;
}

} // namespace

} // namespace mlx::core
@ -9,58 +9,49 @@ import numpy as np

class TestFFT(mlx_tests.MLXTestCase):
    def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
        out_np = op_np(a_np, **kwargs)
        a_mx = mx.array(a_np)
        out_mx = op_mx(a_mx, **kwargs)
        np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)

    def test_fft(self):
        with mx.stream(mx.cpu):
            r = np.random.rand(100).astype(np.float32)
            i = np.random.rand(100).astype(np.float32)
            a_np = r + 1j * i
            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)

            # Check with slicing and padding
            r = np.random.rand(100).astype(np.float32)
            i = np.random.rand(100).astype(np.float32)
            a_np = r + 1j * i
            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)

            # Check different axes
            r = np.random.rand(100, 100).astype(np.float32)
            i = np.random.rand(100, 100).astype(np.float32)
            a_np = r + 1j * i
            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)

            # Check real fft
            a_np = np.random.rand(100).astype(np.float32)
            self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
            self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
            self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)

            # Check real inverse
            r = np.random.rand(100, 100).astype(np.float32)
            i = np.random.rand(100, 100).astype(np.float32)
            a_np = r + 1j * i
            self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
            self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
            self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
            self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
            self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
            self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)

    def test_fftn(self):
        with mx.stream(mx.cpu):
@ -85,7 +76,65 @@ class TestFFT(mlx_tests.MLXTestCase):
                x = a
                if op in ["rfft2", "rfftn"]:
                    x = r
                mx_op = getattr(mx.fft, op)
                np_op = getattr(np.fft, op)
                self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)

    def test_fft_powers_of_two(self):
        shape = (16, 4, 8)
        # np.fft.fft always uses double precision complex128
        # mx.fft.fft only supports single precision complex64
        # hence the fairly tolerant equality checks.
        atol = 1e-4
        rtol = 1e-4
        np.random.seed(7)
        for k in range(4, 12):
            r = np.random.rand(*shape, 2**k).astype(np.float32)
            i = np.random.rand(*shape, 2**k).astype(np.float32)
            a_np = r + 1j * i
            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)

        r = np.random.rand(*shape, 32).astype(np.float32)
        i = np.random.rand(*shape, 32).astype(np.float32)
        a_np = r + 1j * i
        for axis in range(4):
            self.check_mx_np(
                mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol, axis=axis
            )

    def test_fft_contiguity(self):
        r = np.random.rand(4, 8).astype(np.float32)
        i = np.random.rand(4, 8).astype(np.float32)
        a_np = r + 1j * i
        a_mx = mx.array(a_np)

        # non-contiguous in the FFT dim
        out_mx = mx.fft.fft(a_mx[:, ::2])
        out_np = np.fft.fft(a_np[:, ::2])
        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)

        # non-contiguous not in the FFT dim
        out_mx = mx.fft.fft(a_mx[::2])
        out_np = np.fft.fft(a_np[::2])
        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)

        out_mx = mx.broadcast_to(mx.reshape(mx.transpose(a_mx), (4, 8, 1)), (4, 8, 16))
        out_np = np.broadcast_to(np.reshape(np.transpose(a_np), (4, 8, 1)), (4, 8, 16))
        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)

        out2_mx = mx.fft.fft(mx.abs(out_mx) + 4)
        out2_np = np.fft.fft(np.abs(out_np) + 4)
        np.testing.assert_allclose(out2_mx, out2_np, atol=1e-5, rtol=1e-5)

        b_np = np.array([[0, 1, 2, 3]])
        out_mx = mx.abs(mx.fft.fft(mx.tile(mx.reshape(mx.array(b_np), (1, 4)), (4, 1))))
        out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))
        np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    unittest.main()