Mirror of https://github.com/ml-explore/mlx.git
(synced 2025-12-15 17:39:05 +08:00)
Files from awni's commit.
This commit is contained in:
1041
python/tests/test_array.py
Normal file
1041
python/tests/test_array.py
Normal file
File diff suppressed because it is too large
Load Diff
263
python/tests/test_autograd.py
Normal file
263
python/tests/test_autograd.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestAutograd(mlx_tests.MLXTestCase):
    """Tests for mlx automatic differentiation: jvp, vjp, grad and value_and_grad."""

    def test_jvp(self):
        # Scalar chain: tangent 2.0 through d(2x)/dx gives 4.0.
        f = lambda x: 2 * x
        out, dout = mx.jvp(f, [mx.array(1.0)], [mx.array(2.0)])
        self.assertEqual(out[0].item(), 2.0)
        self.assertEqual(dout[0].item(), 4.0)

        # Two primals with two tangents.
        f = lambda x, y: x * y
        _, out = mx.jvp(
            f, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0), mx.array(2.0)]
        )
        self.assertEqual(out[0].item(), 4.0 * 2.0 + 2.0 * 3.0)

        # A function with multiple outputs yields one tangent per output.
        f = lambda x, y, z: (x * y, y * z)
        _, out = mx.jvp(
            f,
            [mx.array(2.0), mx.array(4.0), mx.array(6.0)],
            [mx.array(1.0), mx.array(3.0), mx.array(1.0)],
        )
        self.assertEqual(len(out), 2)
        self.assertEqual(out[0].item(), 4.0 * 1.0 + 2.0 * 3.0)
        self.assertEqual(out[1].item(), 4.0 * 1.0 + 6.0 * 3.0)

    def test_vjp(self):
        # Scalar chain: cotangent 2.0 through d(2x)/dx gives 4.0.
        f = lambda x: 2 * x
        out, dout = mx.vjp(f, [mx.array(1.0)], [mx.array(2.0)])
        self.assertEqual(out[0].item(), 2.0)
        self.assertEqual(dout[0].item(), 4.0)

        # One cotangent fans out into one gradient per primal.
        f = lambda x, y: x * y
        _, dout = mx.vjp(f, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0)])
        self.assertEqual(dout[0].item(), 6.0)
        self.assertEqual(dout[1].item(), 12.0)

        # Multiple outputs take one cotangent each.
        f = lambda x, y, z: (x * y, y * z)
        _, out = mx.vjp(
            f,
            [mx.array(2.0), mx.array(4.0), mx.array(6.0)],
            [mx.array(1.0), mx.array(3.0)],
        )
        self.assertEqual(len(out), 3)
        self.assertEqual(out[0].item(), 4.0 * 1.0)
        self.assertEqual(out[1].item(), 2.0 * 1.0 + 6.0 * 3.0)
        self.assertEqual(out[2].item(), 4.0 * 3.0)

    def test_grad(self):
        f = lambda x: x * x

        # value_and_grad returns both f(x) and df/dx.
        value, dfdx = mx.value_and_grad(f)(mx.array(0.5))
        self.assertEqual(value.item(), 0.25)
        self.assertEqual(dfdx.item(), 1.0)

        dfdx = mx.grad(f)(mx.array(0.5))
        self.assertEqual(dfdx.item(), 1.0)

        # grad composes: second and third derivatives of x^2 are 2 and 0.
        df2dx2 = mx.grad(mx.grad(f))(mx.array(0.5))
        self.assertEqual(df2dx2.item(), 2.0)
        df3dx3 = mx.grad(mx.grad(mx.grad(f)))(mx.array(0.5))
        self.assertEqual(df3dx3.item(), 0.0)

        # argnums selects which positional argument to differentiate.
        f = lambda x, y: x * y
        x = mx.array(2.0)
        y = mx.array(3.0)
        dfdx = mx.grad(f, argnums=0)(x, y)
        self.assertEqual(dfdx.item(), 3.0)
        dfdx = mx.grad(f, argnums=1)(x, y)
        self.assertEqual(dfdx.item(), 2.0)

        # Pass non array args to functions works
        f = lambda x, y: x
        value, dfdx = mx.value_and_grad(f)(mx.array(2.0), "hello")
        self.assertEqual(value.item(), 2.0)
        self.assertEqual(dfdx.item(), 1.0)

        dfdx = mx.grad(f)(mx.array(2.0), "hello")
        self.assertEqual(dfdx.item(), 1.0)

        # Raises when function does not return array
        f = lambda x: "hello"
        with self.assertRaises(ValueError):
            mx.grad(f)(mx.array(2.0))

        # Raises for invalid argument number or argument type
        f = lambda x: x
        with self.assertRaises(ValueError):
            mx.grad(f, argnums=2)(mx.array(2.0))
        with self.assertRaises(ValueError):
            mx.grad(f, argnums=-2)(mx.array(2.0))
        with self.assertRaises(ValueError):
            mx.grad(f)("hello")

        # Raises when output is not a scalar array
        f = lambda x: mx.sum(x, keepdims=True)
        with self.assertRaises(ValueError):
            mx.grad(f)(mx.ones((2, 2)))

    def test_grad_trees(self):
        # Tuple of argnums -> tuple of gradients.
        f = lambda x, y: x * y
        value, dfdx = mx.value_and_grad(f, (0, 1))(mx.array(0.5), mx.array(2.0))
        self.assertEqual(value.item(), 1.0)
        self.assertTrue(isinstance(dfdx, tuple))
        self.assertEqual(dfdx[0].item(), 2.0)
        self.assertEqual(dfdx[1].item(), 0.5)

        # Single int argnum -> a bare gradient.
        f = lambda x, y: x * y
        value, dfdx = mx.value_and_grad(f, 1)(mx.array(0.5), mx.array(2.0))
        self.assertEqual(value.item(), 1.0)
        self.assertEqual(dfdx.item(), 0.5)

        # Dict-valued arguments come back as a dict of gradients.
        f = lambda p: p["x"] * p["y"]
        value, dfdx = mx.value_and_grad(f)({"x": mx.array(0.5), "y": mx.array(2.0)})
        self.assertEqual(value.item(), 1.0)
        self.assertEqual(dfdx["x"].item(), 2.0)
        self.assertEqual(dfdx["y"].item(), 0.5)

        # Non-array leaves and out-of-range argnums are rejected.
        f = lambda p: p["x"] * p["y"]
        with self.assertRaises(ValueError):
            mx.value_and_grad(f)({"x": 0.5, "y": mx.array(2.0)})
        with self.assertRaises(ValueError):
            mx.value_and_grad(f, (0, 1))({"x": mx.array(0.5), "y": mx.array(2.0)})

        # Deeply nested containers are traversed to find arrays.
        f = lambda p, b: mx.square(p[0]["foo"][2]) * b
        value, dfdx = mx.value_and_grad(f)(
            [{"foo": [[], [], mx.array(2.0)]}], mx.array(0.5)
        )
        self.assertEqual(value.item(), 2.0)
        self.assertEqual(dfdx[0]["foo"][2].item(), 2.0)

        # Invalid argnums specifications raise.
        f = lambda x: x
        with self.assertRaises(TypeError):
            mx.value_and_grad(f, (None, None))
        with self.assertRaises(ValueError):
            mx.value_and_grad(f, tuple())

    def test_auxiliary_values(self):
        def loss_fn(x, y):
            l = (x * y).sum()
            extra = {"loss": l, "foo": y.square() + x.square(), "bar": [1, 2, 3, y, x]}
            return l, extra

        fun_value_grad = mx.value_and_grad(loss_fn)
        fun_grad = mx.grad(loss_fn)

        # value_and_grad passes the auxiliary pytree through untouched.
        (loss, a), b = fun_value_grad(mx.ones((2, 2)), mx.ones((2, 2)))
        self.assertEqual(a["loss"].item(), 4)
        self.assertTrue(mx.array_equal(b, mx.ones((2, 2))))
        self.assertTrue(mx.array_equal(a["foo"], 2 * mx.ones((2, 2))))
        self.assertEqual(a["bar"][:3], [1, 2, 3])
        self.assertTrue(mx.array_equal(a["bar"][3], mx.ones((2, 2))))
        self.assertTrue(mx.array_equal(a["bar"][4], mx.ones((2, 2))))

        # Plain grad rejects a function whose output is not a scalar array.
        with self.assertRaises(ValueError):
            _ = fun_grad(mx.ones((2, 2)), mx.ones((2, 2)))

    def test_grad_kwargs(self):
        f = lambda x, y: x * y
        a, b = mx.array(0.5), mx.array(2.0)
        dfdx = mx.grad(f)
        self.assertEqual(dfdx(a, b).item(), 2.0)
        self.assertEqual(dfdx(a, y=b).item(), 2.0)
        # A differentiated positional argument may not be passed by keyword.
        with self.assertRaises(ValueError):
            dfdx(x=a, y=b).item()

        # argnames alone: gradients come back as (None, {name: grad}).
        dfdy = mx.grad(f, argnums=[], argnames=["y"])
        with self.assertRaises(ValueError):
            dfdy(a, b)
        grads = dfdy(a, y=b)
        self.assertTrue(isinstance(grads, tuple))
        self.assertTrue(grads[0] is None)
        self.assertTrue(isinstance(grads[1], dict))
        self.assertEqual(grads[1]["y"].item(), 0.5)
        grads = dfdy(x=a, y=b)
        self.assertEqual(grads[1]["y"].item(), 0.5)
        self.assertEqual(len(grads[1]), 1)

        # Mixing argnums and argnames.
        dfdxy = mx.grad(f, argnums=[0], argnames=["y"])
        with self.assertRaises(ValueError):
            dfdxy(a, b)
        with self.assertRaises(ValueError):
            dfdxy(x=a, y=b)
        grads = dfdxy(a, y=b)
        self.assertTrue(isinstance(grads, tuple))
        self.assertEqual(grads[0].item(), 2.0)
        self.assertTrue(isinstance(grads[1], dict))
        self.assertEqual(grads[1]["y"].item(), 0.5)

        # Several argnums plus an argname.
        f = lambda x, y, z: x * y * z
        dfdxyz = mx.grad(f, argnums=[0, 1], argnames=["z"])
        c = mx.array(4.0)
        grads = dfdxyz(a, b, z=c)
        self.assertTrue(isinstance(grads, tuple))
        self.assertTrue(isinstance(grads[0], tuple))
        self.assertEqual(grads[0][0].item(), 8.0)
        self.assertEqual(grads[0][1].item(), 2.0)
        self.assertTrue(isinstance(grads[1], dict))
        self.assertEqual(grads[1]["z"].item(), 1.0)

        # argnames without argnums.
        f = lambda x, y: x * y
        dfdy = mx.grad(f, argnames=["y"])
        grads = dfdy(a, y=b)
        self.assertTrue(isinstance(grads, tuple))
        self.assertTrue(grads[0] is None)
        self.assertTrue(isinstance(grads[1], dict))
        self.assertEqual(grads[1]["y"].item(), 0.5)

    def test_captured(self):
        # Arrays captured from the enclosing scope are treated as constants.
        a = mx.array(5.0)
        f = lambda x: a + x
        g = lambda x: a + a
        h = lambda x: x + x

        dfdx = mx.grad(f)
        self.assertEqual(dfdx(a).item(), 1.0)

        dgdx = mx.grad(g)
        self.assertEqual(dgdx(a).item(), 0.0)

        dhdx = mx.grad(h)
        self.assertEqual(dhdx(a).item(), 2.0)

        d2fdx2 = mx.grad(dfdx)
        self.assertEqual(d2fdx2(a).item(), 0.0)

        d2gdx2 = mx.grad(dgdx)
        self.assertEqual(d2gdx2(a).item(), 0.0)

        d2hdx2 = mx.grad(dhdx)
        self.assertEqual(d2hdx2(a).item(), 0.0)

    def test_stop_gradient(self):
        shape_in = (4, 4)
        w_in = mx.ones(shape_in)
        x_in = mx.ones(shape_in)
        cotan = mx.ones(shape_in)

        def h(w, x):
            x1 = 2 * x
            # stop_gradient blocks the backward flow from y back to x.
            y = mx.stop_gradient(x1)
            y1 = 3 * y
            return w @ y1

        vals, vjps = mx.vjp(h, [w_in, x_in], [cotan])
        mx.eval(vjps)

        # Gradient w.r.t. w survives; gradient w.r.t. x is cut to zero.
        self.assertTrue(mx.allclose(vjps[0], 24.0 * mx.ones(shape_in)))
        self.assertTrue(mx.allclose(vjps[1], mx.zeros(shape_in)))

        g = lambda x: h(w_in, x)
        vals, vjps = mx.vjp(g, [x_in], [cotan])
        mx.eval(vjps)

        self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
||||
105
python/tests/test_device.py
Normal file
105
python/tests/test_device.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
# Don't inherit from MLXTestCase to avoid call to setUp
|
||||
class TestDefaultDevice(unittest.TestCase):
    """Checks the initial default device.

    Inherits directly from unittest.TestCase (not MLXTestCase) so that no
    setUp hook changes the default device before the assertions run.
    """

    def test_mlx_default_device(self):
        device = mx.default_device()
        if mx.metal.is_available():
            # With Metal available the default is GPU 0, and a Device
            # compares equal with a bare device type in both directions.
            self.assertEqual(device, mx.Device(mx.gpu))
            self.assertEqual(str(device), "Device(gpu, 0)")
            self.assertEqual(device, mx.gpu)
            self.assertEqual(mx.gpu, device)
        else:
            # Fix: compare the device itself, mirroring the GPU branch.
            # The original compared `device.type` (a device type) against
            # `mx.Device(mx.cpu)` (a Device instance).
            self.assertEqual(device, mx.Device(mx.cpu))
            # Selecting the GPU without Metal must fail.
            with self.assertRaises(ValueError):
                mx.set_default_device(mx.gpu)
|
||||
|
||||
|
||||
class TestDevice(mlx_tests.MLXTestCase):
    """Exercises setting/getting the default device and per-op device selection."""

    def test_device(self):
        # Remember the current default so it can be restored at the end.
        original_device = mx.default_device()

        cpu_device = mx.Device(mx.cpu)
        mx.set_default_device(cpu_device)
        self.assertEqual(mx.default_device(), cpu_device)
        self.assertEqual(str(cpu_device), "Device(cpu, 0)")

        # A bare device type is accepted and compares equal with a Device.
        mx.set_default_device(mx.cpu)
        self.assertEqual(mx.default_device(), mx.cpu)
        self.assertEqual(cpu_device, mx.cpu)
        self.assertEqual(mx.cpu, cpu_device)

        # Restore device
        mx.set_default_device(original_device)

    def test_op_on_device(self):
        x = mx.array(1.0)
        y = mx.array(1.0)

        # stream=None, an explicit Device, and a device type all behave alike.
        reference = mx.add(x, y, stream=None)
        result = mx.add(x, y, stream=mx.default_device())
        self.assertEqual(reference.item(), result.item())
        result = mx.add(x, y, stream=mx.cpu)
        self.assertEqual(reference.item(), result.item())

        if mx.metal.is_available():
            result = mx.add(x, y, stream=mx.gpu)
            self.assertEqual(reference.item(), result.item())
|
||||
|
||||
|
||||
class TestStream(mlx_tests.MLXTestCase):
    """Exercises stream creation and running ops on explicit streams."""

    def test_stream(self):
        default_stream = mx.default_stream(mx.default_device())
        self.assertEqual(default_stream.device, mx.default_device())

        # A freshly created stream lives on the same device but is distinct.
        fresh_stream = mx.new_stream(mx.default_device())
        self.assertEqual(fresh_stream.device, mx.default_device())
        self.assertNotEqual(default_stream, fresh_stream)

        if mx.metal.is_available():
            s_gpu = mx.default_stream(mx.gpu)
            self.assertEqual(s_gpu.device, mx.gpu)
        else:
            # Asking for a GPU stream without Metal is an error.
            with self.assertRaises(ValueError):
                mx.default_stream(mx.gpu)

        s_cpu = mx.default_stream(mx.cpu)
        self.assertEqual(s_cpu.device, mx.cpu)

        s_cpu = mx.new_stream(mx.cpu)
        self.assertEqual(s_cpu.device, mx.cpu)

        if mx.metal.is_available():
            s_gpu = mx.new_stream(mx.gpu)
            self.assertEqual(s_gpu.device, mx.gpu)
        else:
            with self.assertRaises(ValueError):
                mx.new_stream(mx.gpu)

    def test_op_on_stream(self):
        x = mx.array(1.0)
        y = mx.array(1.0)

        baseline = mx.add(x, y, stream=mx.default_stream(mx.default_device()))

        if mx.metal.is_available():
            # Default and fresh GPU streams both compute the same result.
            other = mx.add(x, y, stream=mx.default_stream(mx.gpu))
            self.assertEqual(baseline.item(), other.item())
            s_gpu = mx.new_stream(mx.gpu)
            other = mx.add(x, y, stream=s_gpu)
            self.assertEqual(baseline.item(), other.item())

        # Same checks on the CPU streams.
        other = mx.add(x, y, stream=mx.default_stream(mx.cpu))
        self.assertEqual(baseline.item(), other.item())
        s_cpu = mx.new_stream(mx.cpu)
        other = mx.add(x, y, stream=s_cpu)
        self.assertEqual(baseline.item(), other.item())
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
||||
34
python/tests/test_eval.py
Normal file
34
python/tests/test_eval.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from functools import partial
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestEval(mlx_tests.MLXTestCase):
    """Tests for explicit graph evaluation via mx.eval."""

    def test_eval(self):
        # Evaluating several arrays at once materializes all of them.
        arrs = [mx.ones((2, 2)) for _ in range(4)]
        mx.eval(*arrs)
        for x in arrs:
            self.assertEqual(x.tolist(), [[1, 1], [1, 1]])

    def test_retain_graph(self):
        def fun(x, retain_graph):
            y = 3 * x
            # Evaluating mid-computation drops the graph unless retained.
            mx.eval(y, retain_graph=retain_graph)
            return 2 * y

        dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
        dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))

        # Without retaining the graph, differentiation through eval fails...
        with self.assertRaises(ValueError):
            dfun_dx_1(mx.array(1.0))

        # ...with it, the full chain d(2*3*x)/dx == 6 is recovered.
        y = dfun_dx_2(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
||||
90
python/tests/test_fft.py
Normal file
90
python/tests/test_fft.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import unittest
|
||||
|
||||
import itertools
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestFFT(mlx_tests.MLXTestCase):
    """Compares mlx FFT ops against their NumPy counterparts."""

    def check_mx_np(self, op, a_np, axes, s):
        # Run the same n-dimensional transform in NumPy and mlx and compare.
        with self.subTest(op=op, axes=axes, s=s):
            op_np = getattr(np.fft, op)
            op_mx = getattr(mx.fft, op)
            out_np = op_np(a_np, s=s, axes=axes)
            a_mx = mx.array(a_np)
            out_mx = op_mx(a_mx, s=s, axes=axes)
            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

    def test_fft(self):
        # FFTs run on the CPU; restore the default device afterwards.
        default = mx.default_device()
        mx.set_default_device(mx.cpu)

        def check_mx_np(op_mx, op_np, a_np, **kwargs):
            out_np = op_np(a_np, **kwargs)
            a_mx = mx.array(a_np)
            out_mx = op_mx(a_mx, **kwargs)
            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

        real = np.random.rand(100).astype(np.float32)
        imag = np.random.rand(100).astype(np.float32)
        a_np = real + 1j * imag
        check_mx_np(mx.fft.fft, np.fft.fft, a_np)

        # Check with slicing and padding
        real = np.random.rand(100).astype(np.float32)
        imag = np.random.rand(100).astype(np.float32)
        a_np = real + 1j * imag
        check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
        check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)

        # Check different axes
        real = np.random.rand(100, 100).astype(np.float32)
        imag = np.random.rand(100, 100).astype(np.float32)
        a_np = real + 1j * imag
        check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
        check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)

        # Check real fft
        a_np = np.random.rand(100).astype(np.float32)
        check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
        check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
        check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)

        # Check real inverse
        real = np.random.rand(100, 100).astype(np.float32)
        imag = np.random.rand(100, 100).astype(np.float32)
        a_np = real + 1j * imag
        check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
        check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
        check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
        check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
        check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
        check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)

        mx.set_default_device(default)

    def test_fftn(self):
        default = mx.default_device()
        mx.set_default_device(mx.cpu)

        real = np.random.randn(8, 8, 8).astype(np.float32)
        imag = np.random.randn(8, 8, 8).astype(np.float32)
        complex_input = real + 1j * imag

        axes = [None, (1, 2), (2, 1), (0, 2)]
        shapes = [None, (10, 5), (5, 10)]
        ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"]

        for op, ax, s in itertools.product(ops, axes, shapes):
            # Real-input transforms take the real array; the rest the complex one.
            x = real if op in ["rfft2", "rfftn"] else complex_input
            self.check_mx_np(op, x, axes=ax, s=s)

        mx.set_default_device(default)
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
||||
1283
python/tests/test_ops.py
Normal file
1283
python/tests/test_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
118
python/tests/test_reduce.py
Normal file
118
python/tests/test_reduce.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import unittest
|
||||
from itertools import permutations, combinations
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestReduce(mlx_tests.MLXTestCase):
    """Compares mlx reductions with NumPy across layouts, dtypes, and axes."""

    def test_axis_permutation_sums(self):
        # Sum over every axis subset of every transpose of a 5-d array.
        x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32)
        x_mlx = mx.array(x_npy)
        for perm in permutations(range(5)):
            with self.subTest(t=perm):
                y_npy = np.transpose(x_npy, perm)
                y_mlx = mx.transpose(x_mlx, perm)
                for count in range(1, 6):
                    for axes in combinations(range(5), count):
                        with self.subTest(a=axes):
                            z_npy = np.sum(y_npy, axis=axes)
                            z_mlx = mx.sum(y_mlx, axis=axes)
                            mx.eval(z_mlx)
                            self.assertTrue(
                                np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
                            )

    def test_expand_sums(self):
        # Sum over broadcast-expanded views of a (5,1,5,1,5,1) array.
        x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32)
        x_mlx = mx.array(x_npy)
        for count in range(1, 4):
            for expand_axes in combinations([1, 3, 5], count):
                shape = np.array([5, 1, 5, 1, 5, 1])
                shape[list(expand_axes)] = 5
                shape = shape.tolist()
                with self.subTest(shape=shape):
                    y_npy = np.broadcast_to(x_npy, shape)
                    y_mlx = mx.broadcast_to(x_mlx, shape)
                    for n_axes in range(1, 7):
                        for axes in combinations(range(6), n_axes):
                            with self.subTest(a=axes):
                                # Scale down to keep the comparison well inside
                                # the absolute tolerance.
                                z_npy = np.sum(y_npy, axis=axes) / 1000
                                z_mlx = mx.sum(y_mlx, axis=axes) / 1000
                                mx.eval(z_mlx)
                                self.assertTrue(
                                    np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
                                )

    def test_dtypes(self):
        int_dtypes = [
            "int8",
            "int16",
            "int32",
            "uint8",
            "uint16",
            "uint32",
        ]
        float_dtypes = ["float32"]

        for dtype in int_dtypes + float_dtypes:
            with self.subTest(dtype=dtype):
                x = np.random.uniform(0, 2, size=(3, 3, 3)).astype(getattr(np, dtype))
                y = mx.array(x)

                for op in ("sum", "prod", "min", "max"):
                    with self.subTest(op=op):
                        np_op = getattr(np, op)
                        mlx_op = getattr(mx, op)

                        for axes in (None, 0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
                            with self.subTest(axes=axes):
                                if op in ("sum", "prod"):
                                    # Pin NumPy's accumulator dtype to the
                                    # input dtype to match mlx semantics.
                                    r_np = np_op(
                                        x, axis=axes, dtype=(getattr(np, dtype))
                                    )
                                else:
                                    r_np = np_op(x, axis=axes)
                                r_mlx = mlx_op(y, axis=axes)
                                mx.eval(r_mlx)
                                self.assertTrue(np.allclose(r_np, r_mlx, atol=1e-4))

    def test_arg_reduce(self):
        dtypes = [
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "int32",
            "int64",
            "float16",
            "float32",
        ]
        for dtype in dtypes:
            with self.subTest(dtype=dtype):
                data = np.random.rand(10, 12, 13).astype(getattr(np, dtype))
                x = mx.array(data)
                # Per-axis argmin/argmax, with and without kept dims.
                for op in ["argmin", "argmax"]:
                    for axis in range(3):
                        for kd in [True, False]:
                            a = getattr(mx, op)(x, axis, kd)
                            b = getattr(np, op)(data, axis, keepdims=kd)
                            self.assertEqual(a.tolist(), b.tolist())

                # Full reductions over the flattened array.
                for op in ["argmin", "argmax"]:
                    a = getattr(mx, op)(x, keepdims=True)
                    b = getattr(np, op)(data, keepdims=True)
                    self.assertEqual(a.tolist(), b.tolist())
                    a = getattr(mx, op)(x)
                    b = getattr(np, op)(data)
                    self.assertEqual(a.item(), b)
|
||||
|
||||
|
||||
# Allow running this test file directly; stop on the first failure since
# the reduction sweeps generate many subtests.
if __name__ == "__main__":
    unittest.main(failfast=True)
|
||||
Reference in New Issue
Block a user