mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 17:44:38 +08:00
fix scatter + test (#1202)
* fix scatter + test * fix test warnings * fix metal validation
This commit is contained in:
@@ -49,7 +49,7 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||
|
||||
x = np.fft.rfft(a_np)
|
||||
x = np.fft.rfft(np.real(a_np))
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
|
||||
|
||||
def test_fftn(self):
|
||||
@@ -75,9 +75,9 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
elif op == "irfft2":
|
||||
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s))
|
||||
x = np.ascontiguousarray(np.fft.rfft2(r, axes=ax, s=s))
|
||||
elif op == "irfftn":
|
||||
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s))
|
||||
x = np.ascontiguousarray(np.fft.rfftn(r, axes=ax, s=s))
|
||||
mx_op = getattr(mx.fft, op)
|
||||
np_op = getattr(np.fft, op)
|
||||
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
||||
@@ -93,7 +93,7 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
|
||||
|
||||
ia_np = np.fft.rfft(a_np)
|
||||
ia_np = np.fft.rfft(r)
|
||||
self.check_mx_np(
|
||||
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
|
||||
)
|
||||
|
Reference in New Issue
Block a user