fix scatter + test (#1202)

* fix scatter + test

* fix test warnings

* fix metal validation
This commit is contained in:
Awni Hannun
2024-06-11 14:35:12 -07:00
committed by GitHub
parent 709ccc6800
commit df964132fb
5 changed files with 54 additions and 9 deletions

View File

@@ -49,7 +49,7 @@ class TestFFT(mlx_tests.MLXTestCase):
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
x = np.fft.rfft(a_np)
x = np.fft.rfft(np.real(a_np))
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
def test_fftn(self):
@@ -75,9 +75,9 @@ class TestFFT(mlx_tests.MLXTestCase):
if op in ["rfft2", "rfftn"]:
x = r
elif op == "irfft2":
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s))
x = np.ascontiguousarray(np.fft.rfft2(r, axes=ax, s=s))
elif op == "irfftn":
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s))
x = np.ascontiguousarray(np.fft.rfftn(r, axes=ax, s=s))
mx_op = getattr(mx.fft, op)
np_op = getattr(np.fft, op)
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
@@ -93,7 +93,7 @@ class TestFFT(mlx_tests.MLXTestCase):
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
ia_np = np.fft.rfft(a_np)
ia_np = np.fft.rfft(r)
self.check_mx_np(
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
)