Jagrit's commit files

This commit is contained in:
Jagrit Digani
2023-11-29 10:52:08 -08:00
parent d1f86272a2
commit e6306cfee9
74 changed files with 15964 additions and 2 deletions

python/tests/mlx_tests.py Normal file

@@ -0,0 +1,16 @@
import os
import unittest
import mlx.core as mx
class MLXTestCase(unittest.TestCase):
def setUp(self):
self.default = mx.default_device()
device = os.getenv("DEVICE", None)
if device is not None:
device = getattr(mx, device)
mx.set_default_device(device)
def tearDown(self):
mx.set_default_device(self.default)
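The base class above reads the DEVICE environment variable in setUp to select the default device for every test and restores the previous default in tearDown. A minimal sketch of the same selection pattern outside the test harness (assuming mlx.core exposes the named device, e.g. cpu or gpu):

import os
import mlx.core as mx

# Pick the device named in DEVICE ("cpu" or "gpu"); otherwise keep the current default.
name = os.getenv("DEVICE")
if name is not None:
    mx.set_default_device(getattr(mx, name))
print(mx.default_device())

Running the suite as, for example, DEVICE=cpu python -m unittest discover python/tests would then exercise every test on the CPU.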

python/tests/test_blas.py Normal file

@@ -0,0 +1,445 @@
import unittest
from itertools import permutations
import math
import mlx.core as mx
import numpy as np
import mlx_tests
class TestBlas(mlx_tests.MLXTestCase):
@property
def dtypes(self):
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
def __gemm_test(
self,
shape_a,
shape_b,
np_dtype=np.float32,
f_np_a=lambda x: x,
f_np_b=lambda x: x,
f_mx_a=lambda x: x,
f_mx_b=lambda x: x,
):
with self.subTest(
dtype=np.dtype(np_dtype).name, shape_a=shape_a, shape_b=shape_b
):
np.random.seed(42)
scale = max(np.sum(shape_a), 128)
a_np = np.random.normal(0.0, 1.0 / scale, shape_a).astype(np_dtype)
b_np = np.random.normal(0.0, 1.0 / scale, shape_b).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_np = f_np_a(a_np.astype(np.float32))
b_np = f_np_b(b_np.astype(np.float32))
a_mx = f_mx_a(a_mx)
b_mx = f_mx_b(b_mx)
out_npy = a_np @ b_np
out_mlx = a_mx @ b_mx
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
def test_matmul_unaligned(self):
if not mx.metal.is_available():
return
for dtype in self.dtypes:
np_dtype = getattr(np, dtype)
base_shapes = [4, 8, 16, 32, 64, 128]
perturbations = [-2, -1, 0, 1, 2]
for dim in base_shapes:
for p in perturbations:
shape_a = (dim + p, dim + p)
shape_b = (dim + p, dim + p)
self.__gemm_test(shape_a, shape_b, np_dtype)
def test_matmul_shapes(self):
if not mx.metal.is_available():
return
shapes = [
(1, 2, 1, 1),
(1, 1, 2, 1),
(3, 23, 457, 3),
]
if mx.default_device() == mx.gpu:
shapes += [
(16, 768, 768, 128),
]
for dtype in self.dtypes:
np_dtype = getattr(np, dtype)
for B, M, N, K in shapes:
with self.subTest(transpose="nn"):
shape_a = (B, M, K)
shape_b = (B, K, N)
self.__gemm_test(shape_a, shape_b, np_dtype)
with self.subTest(transpose="nt"):
shape_a = (B, M, K)
shape_b = (B, N, K)
self.__gemm_test(
shape_a,
shape_b,
np_dtype,
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
)
with self.subTest(transpose="tn"):
shape_a = (B, K, M)
shape_b = (B, K, N)
self.__gemm_test(
shape_a,
shape_b,
np_dtype,
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
)
with self.subTest(transpose="tt"):
shape_a = (B, K, M)
shape_b = (B, N, K)
self.__gemm_test(
shape_a,
shape_b,
np_dtype,
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
)
def test_matmul(self):
# Note: so far, matmul only works with floating-point types
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
b = mx.array([[0.0, -1.0], [-3.0, 3.0]])
expected = [[-6.0, 5.0], [-12.0, 9.0]]
self.assertEqual((a @ b).tolist(), expected)
self.assertEqual(mx.matmul(a, b).tolist(), expected)
# Transposed matmul
np.random.seed(0)
a_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
d_npy = np.transpose(a_npy, (1, 0)) @ b_npy
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
d_mlx = mx.transpose(a_mlx, (1, 0)) @ b_mlx
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
def test_matmul_dtypes(self):
for dt in self.dtypes:
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
getattr(np, dt)
)
b_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
getattr(np, dt)
)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_npy = np.matmul(a_npy, b_npy, dtype=getattr(np, dt))
c_mlx = a_mlx @ b_mlx
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
def test_matmul_batched(self):
np.random.seed(0)
# Batched matmul
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
c_npy = a_npy @ b_npy
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_mlx = a_mlx @ b_mlx
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
c_npy = a_npy @ np.transpose(b_npy, (0, 2, 1))
b_mlx = mx.array(b_npy)
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 2, 1))
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
c_npy = a_npy @ b_npy
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_mlx = a_mlx @ b_mlx
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Both operands broadcasted
d_npy = np.broadcast_to(b_npy, (5, 16, 16))
d_mlx = mx.broadcast_to(b_mlx, (5, 16, 16))
e_npy = d_npy @ d_npy
e_mlx = d_mlx @ d_mlx
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
# Batched and transposed matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_npy = a_npy @ b_npy
c_mlx = a_mlx @ b_mlx
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Test Multiheaded attention style matmul
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_npy = np.transpose(a_npy, (0, 2, 1, 3))
b_npy = np.transpose(b_npy, (0, 2, 1, 3))
a_mlx = mx.transpose(a_mlx, (0, 2, 1, 3))
b_mlx = mx.transpose(b_mlx, (0, 2, 1, 3))
c_npy = a_npy @ np.transpose(b_npy, (0, 1, 3, 2))
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 1, 3, 2))
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
def __gemv_test(
self,
shape_mat,
shape_vec,
np_dtype=np.float32,
mat_first=True,
np_mat_f=lambda x: x,
np_vec_f=lambda x: x,
mlx_mat_f=lambda x: x,
mlx_vec_f=lambda x: x,
):
with self.subTest(shape=shape_mat):
np.random.seed(42)
scale = max(np.sum(shape_mat), 32)
mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
vec_npy = np.random.normal(0.0, 1.0 / scale, shape_vec).astype(np_dtype)
mat_mlx = mx.array(mat_npy)
vec_mlx = mx.array(vec_npy)
mat_npy = np_mat_f(mat_npy)
vec_npy = np_vec_f(vec_npy)
mat_mlx = mlx_mat_f(mat_mlx)
vec_mlx = mlx_vec_f(vec_mlx)
if mat_first:
out_npy = mat_npy @ vec_npy
out_mlx = mat_mlx @ vec_mlx
else:
out_npy = vec_npy @ mat_npy
out_mlx = vec_mlx @ mat_mlx
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
self.assertTrue(np.allclose(out_mlx, out_npy, atol=1e-5))
def test_matrix_vector(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
np_dtype = getattr(np, dtype)
# Basic square matrix test
self.__gemv_test(
shape_mat=(64, 64), shape_vec=(64, 1), np_dtype=np_dtype
)
self.__gemv_test(
shape_mat=(64, 64),
shape_vec=(64, 1),
np_dtype=np_dtype,
mat_first=False,
np_vec_f=lambda x: np.transpose(x, (1, 0)),
mlx_vec_f=lambda x: mx.transpose(x, (1, 0)),
)
# Vector matrix product with aligned and unaligned shapes
for in_len_base, out_len_base in (
(2, 2),
(32, 32),
(64, 64),
(2048, 2048),
):
for mi in (-1, 0, 1):
for mj in (-1, 0, 1):
# Vec mat
shape_mat = (in_len_base + mi, out_len_base + mj)
shape_vec = (1, in_len_base + mi)
self.__gemv_test(
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
)
# Mat vec
shape_mat = (out_len_base + mj, in_len_base + mi)
shape_vec = (in_len_base + mi, 1)
self.__gemv_test(
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
)
def test_matrix_vector_batched(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
np_dtype = getattr(np, dtype)
# Batched mat vec
for shape_mat, shape_vec in (
((32, 128, 64), (32, 64, 1)),
((128, 64), (32, 64, 1)),
((32, 128, 64), (64, 1)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
)
# Batched vec mat
for shape_vec, shape_mat in (
((32, 1, 128), (32, 128, 64)),
((32, 1, 128), (128, 64)),
((1, 128), (32, 128, 64)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
)
def test_matrix_vector_broadcast(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
np_dtype = getattr(np, dtype)
# Different broadcasts mat vec
for shape_mat, shape_vec in (
((32, 64, 64), (32, 64, 1)),
((64, 64), (32, 64, 1)),
((32, 64, 64), (64, 1)),
):
self.__gemv_test(
shape_mat=(64, 64),
shape_vec=(64, 1),
np_dtype=np_dtype,
np_mat_f=(lambda mat_npy: np.broadcast_to(mat_npy, shape_mat)),
np_vec_f=(lambda vec_npy: np.broadcast_to(vec_npy, shape_vec)),
mlx_mat_f=(lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat)),
mlx_vec_f=(lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec)),
)
# Different broadcasts vec mat
for shape_vec, shape_mat in (
((32, 1, 64), (32, 64, 64)),
((32, 1, 64), (64, 64)),
((1, 64), (32, 64, 64)),
):
self.__gemv_test(
shape_mat=(64, 64),
shape_vec=(1, 64),
np_dtype=np_dtype,
mat_first=False,
np_mat_f=lambda mat_npy: np.broadcast_to(mat_npy, shape_mat),
np_vec_f=lambda vec_npy: np.broadcast_to(vec_npy, shape_vec),
mlx_mat_f=lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat),
mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
)
def test_matrix_vector_edgecases(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
np_dtype = getattr(np, dtype)
for in_vec_len in np.arange(1, 5):
for out_vec_len in np.arange(1, 5):
for batch_size in np.arange(1, 5):
with self.subTest(
problem_shape=(batch_size, in_vec_len, out_vec_len)
):
# Matrix vector
with self.subTest(transpose=False):
a_npy = np.ones(
(batch_size, out_vec_len, in_vec_len),
dtype=np_dtype,
)
b_npy = np.ones(
(batch_size, in_vec_len, 1), dtype=np_dtype
)
for i in range(batch_size):
b_npy[i] *= i + 1.0
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
c_npy = a_npy @ b_npy
c_mlx = a_mlx @ b_mlx
self.assertListEqual(
list(c_npy.shape), list(c_mlx.shape)
)
self.assertTrue(np.array_equal(c_mlx, c_npy))
# Vector matrix
with self.subTest(transpose=True):
a_npy = np.ones(
(batch_size, out_vec_len, in_vec_len),
dtype=np_dtype,
)
b_npy = np.ones(
(batch_size, 1, out_vec_len), dtype=np_dtype
)
for i in range(batch_size):
b_npy[i] *= i + 1.0
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
c_npy = b_npy @ a_npy
c_mlx = b_mlx @ a_mlx
self.assertListEqual(
list(c_npy.shape), list(c_mlx.shape)
)
self.assertTrue(np.array_equal(c_mlx, c_npy))
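The nn/nt/tn/tt subtests above follow the usual GEMM naming for which operand is stored transposed before the product. A small NumPy illustration of the shapes involved (dimensions chosen only for this example):

import numpy as np

B, M, N, K = 2, 3, 4, 5
a = np.zeros((B, M, K))                   # "n": a stored as (B, M, K)
b_t = np.zeros((B, N, K))                 # "t": b stored as (B, N, K), transposed before use
out = a @ np.transpose(b_t, (0, 2, 1))    # the "nt" case
print(out.shape)                          # (2, 3, 4), i.e. (B, M, N)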

python/tests/test_conv.py Normal file

@@ -0,0 +1,445 @@
import unittest
from itertools import permutations
import math
import mlx.core as mx
import numpy as np
import mlx_tests
try:
import torch
import torch.nn.functional as F
has_torch = True
except ImportError:
has_torch = False
class TestConv(mlx_tests.MLXTestCase):
def test_numpy_conv(self):
for dtype in (
"float16",
"float32",
):
np_dtype = getattr(np, dtype)
for M, N, mode in (
(1, 1, "full"),
(25, 5, "full"),
(24, 5, "same"),
(24, 4, "same"),
(24, 4, "valid"),
(4, 24, "full"),
(5, 25, "same"),
(4, 25, "valid"),
):
with self.subTest(dtype=dtype, M=M, N=N, mode=mode):
atol = 1e-6 if dtype == "float32" else 1e-5
a_np = np.random.rand(M).astype(np_dtype)
v_np = np.random.rand(N).astype(np_dtype)
a_mx = mx.array(a_np)
v_mx = mx.array(v_np)
c_np = np.convolve(a_np, v_np, mode=mode)
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_1D(self):
def run_conv1D(
N,
C,
O,
iH,
kH,
stride,
padding,
dilation=1,
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
iH=iH,
kH=kH,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 2, 1)), (in_np, wt_np)
)
out_mx = mx.conv1d(
in_mx,
wt_mx,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.conv1d(
in_pt,
wt_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for iH, kH, stride, padding in (
(1, 1, 1, 0),
(3, 3, 1, 0),
(31, 5, 5, 2),
):
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
# Strided inputs tests
for tpose_in, tpose_wt in (
((0, 2, 1), (0, 1, 2)),
((0, 2, 1), (0, 2, 1)),
):
with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_mx_t = mx.transpose(in_mx, tpose_in)
wt_mx_t = mx.transpose(wt_mx, tpose_wt)
out_mx = mx.conv1d(in_mx_t, wt_mx_t)
in_pt, wt_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
(in_np.transpose(tpose_in), wt_np.transpose(tpose_wt)),
)
out_pt = torch.conv1d(in_pt, wt_pt)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_1D_grad(self):
def run_conv1D_grad(
N,
C,
O,
iH,
kH,
stride,
padding,
dilation=1,
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
iH=iH,
kH=kH,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
ct_np = np.random.normal(0, 1.0 / C, (N, oH, O)).astype(np_dtype)
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
in_pt, wt_pt, ct_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
(in_np, wt_np, ct_np),
)
def f(a, b):
return mx.conv1d(
a,
b,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
_, outs_mx = mx.vjp(
f,
[
in_mx,
wt_mx,
],
[
ct_mx,
],
)
pt_grad_in = F.grad.conv1d_input(
in_pt.shape,
wt_pt,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_wt = F.grad.conv1d_weight(
in_pt,
wt_pt.shape,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_in = torch.transpose(pt_grad_in, 2, 1).numpy()
pt_grad_wt = torch.transpose(pt_grad_wt, 2, 1).numpy()
mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for iH, kH, stride, padding in (
(1, 1, 1, 0),
(3, 3, 1, 0),
(31, 5, 5, 2),
):
run_conv1D_grad(N, C, O, iH, kH, stride, padding, dtype=dtype)
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_2D(self):
def run_conv2D(
N,
C,
O,
idim,
kdim,
stride,
padding,
dilation=(1, 1),
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
idim=idim,
kdim=kdim,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
iH, iW = idim
kH, kW = kdim
scale = 1.0 / math.sqrt(kH * kW * C)
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
(in_np, wt_np),
)
out_mx = mx.conv2d(
in_mx,
wt_mx,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.conv2d(
in_pt,
wt_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for idim, kdim, stride, padding in (
((1, 1), (1, 1), (1, 1), (0, 0)),
((3, 3), (3, 1), (1, 1), (0, 0)),
((31, 31), (5, 5), (5, 5), (2, 2)),
):
run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_2D_grad(self):
def run_conv2D_grad(
N,
C,
O,
idim,
kdim,
stride,
padding,
dilation=(1, 1),
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
idim=idim,
kdim=kdim,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
iH, iW = idim
kH, kW = kdim
scale = 1.0 / math.sqrt(kH * kW * C)
oH = 1 + (
(iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0]
)
oW = 1 + (
(iW + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1]
)
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)
ct_np = np.random.normal(0.0, scale, (N, oH, oW, O)).astype(np_dtype)
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
in_pt, wt_pt, ct_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
(in_np, wt_np, ct_np),
)
def f(a, b):
return mx.conv2d(
a,
b,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
_, outs_mx = mx.vjp(
f,
[
in_mx,
wt_mx,
],
[
ct_mx,
],
)
pt_grad_in = F.grad.conv2d_input(
in_pt.shape,
wt_pt,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_wt = F.grad.conv2d_weight(
in_pt,
wt_pt.shape,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 1)).numpy()
pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 1)).numpy()
mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for idim, kdim, stride, padding in (
((1, 1), (1, 1), (1, 1), (0, 0)),
((3, 3), (3, 1), (1, 1), (0, 0)),
((31, 31), (5, 5), (5, 5), (2, 2)),
):
run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
if __name__ == "__main__":
unittest.main()
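The gradient tests above size the cotangent with the standard convolution output-length formula, oH = 1 + (iH + 2*padding - dilation*(kH - 1) - 1) // stride. A quick worked check using the largest 1D case from the test grid (illustrative only):

# Output length of a 1D convolution:
# floor((iH + 2*padding - dilation*(kH - 1) - 1) / stride) + 1
iH, kH, stride, padding, dilation = 31, 5, 5, 2, 1
oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)
print(oH)  # 7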

python/tests/test_load.py Normal file

@@ -0,0 +1,157 @@
import unittest
import os
import mlx.core as mx
import numpy as np
import tempfile
import mlx_tests
class TestLoad(mlx_tests.MLXTestCase):
dtypes = [
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float32",
"float16",
"complex64",
]
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_save_and_load(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
with self.subTest(shape=shape):
save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")
save_arr = np.random.uniform(0.0, 32.0, size=shape)
save_arr_npy = save_arr.astype(getattr(np, dt))
save_arr_mlx = mx.array(save_arr_npy)
mx.save(save_file_mlx, save_arr_mlx)
np.save(save_file_npy, save_arr_npy)
# Load array saved by mlx as mlx array
load_arr_mlx_mlx = mx.load(save_file_mlx)
self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))
# Load array saved by numpy as mlx array
load_arr_npy_mlx = mx.load(save_file_npy)
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))
# Load array saved by mlx as numpy array
load_arr_mlx_npy = np.load(save_file_mlx)
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
def test_save_and_load_fs(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
with self.subTest(shape=shape):
save_file_mlx = os.path.join(
self.test_dir, f"mlx_{dt}_{i}_fs.npy"
)
save_file_npy = os.path.join(
self.test_dir, f"npy_{dt}_{i}_fs.npy"
)
save_arr = np.random.uniform(0.0, 32.0, size=shape)
save_arr_npy = save_arr.astype(getattr(np, dt))
save_arr_mlx = mx.array(save_arr_npy)
with open(save_file_mlx, "wb") as f:
mx.save(f, save_arr_mlx)
np.save(save_file_npy, save_arr_npy)
# Load array saved by mlx as mlx array
with open(save_file_mlx, "rb") as f:
load_arr_mlx_mlx = mx.load(f)
self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))
# Load array saved by numpy as mlx array
with open(save_file_npy, "rb") as f:
load_arr_npy_mlx = mx.load(f)
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))
# Load array saved by mlx as numpy array
load_arr_mlx_npy = np.load(save_file_mlx)
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
def test_savez_and_loadz(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes:
with self.subTest(dtype=dt):
shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
save_file_mlx_uncomp = os.path.join(
self.test_dir, f"mlx_{dt}_uncomp.npz"
)
save_file_npy_uncomp = os.path.join(
self.test_dir, f"npy_{dt}_uncomp.npz"
)
save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")
# Make a dictionary of arrays with multiple shapes
save_arrs_npy = {
f"save_arr_{i}": np.random.uniform(
0.0, 32.0, size=shapes[i]
).astype(getattr(np, dt))
for i in range(len(shapes))
}
save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}
# Save as npz files
np.savez(save_file_npy_uncomp, **save_arrs_npy)
mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)
for save_file_npy, save_file_mlx in (
(save_file_npy_uncomp, save_file_mlx_uncomp),
(save_file_npy_comp, save_file_mlx_comp),
):
# Load array saved by mlx as mlx array
load_arr_mlx_mlx = mx.load(save_file_mlx)
for k, v in load_arr_mlx_mlx.items():
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))
# Load arrays saved by numpy as mlx arrays
load_arr_npy_mlx = mx.load(save_file_npy)
for k, v in load_arr_npy_mlx.items():
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))
# Load array saved by mlx as numpy array
load_arr_mlx_npy = np.load(save_file_mlx)
for k, v in load_arr_mlx_npy.items():
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
if __name__ == "__main__":
unittest.main()

python/tests/test_nn.py Normal file

@@ -0,0 +1,231 @@
import unittest
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
import numpy as np
import os
import tempfile
import mlx_tests
class TestNN(mlx_tests.MLXTestCase):
def test_linear(self):
inputs = mx.zeros((10, 4))
layer = nn.Linear(input_dims=4, output_dims=8)
outputs = layer(inputs)
self.assertEqual(tuple(outputs.shape), (10, 8))
def test_cross_entropy(self):
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])
losses = nn.losses.cross_entropy(logits, targets)
self.assertTrue(mx.array_equal(losses, mx.zeros((2,))))
def test_gelu(self):
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
# From: jax.nn.gelu(np.array(inputs), approximate=False)
expected = np.array(
[1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
)
out = nn.GELU()(mx.array(inputs))
self.assertTrue(np.allclose(out, expected))
# Crudely check the approximations
x = mx.arange(-6.0, 6.0, 12 / 100)
y = nn.gelu(x)
y_hat1 = nn.gelu_approx(x)
y_hat2 = nn.gelu_fast_approx(x)
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
def test_group_norm(self):
x = mx.arange(100, dtype=mx.float32)
x = x.reshape(1, 10, 10, 1)
x = mx.broadcast_to(x, (2, 10, 10, 4))
x = mx.concatenate([x, 0.5 * x], axis=-1)
# Group norm in groups last mode
g = nn.GroupNorm(2, 8)
y = g(x)
means = y.reshape(2, -1, 2).mean(axis=1)
var = y.reshape(2, -1, 2).var(axis=1)
self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
g.weight = g.weight * 2
g.bias = g.bias + 3
y = g(x)
means = y.reshape(2, -1, 2).mean(axis=1)
var = y.reshape(2, -1, 2).var(axis=1)
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
# Group norm in groups first mode
g = nn.GroupNorm(2, 8, pytorch_compatible=True)
y = g(x)
means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
g.weight = g.weight * 2
g.bias = g.bias + 3
y = g(x)
means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
def test_conv1d(self):
N = 5
L = 12
ks = 3
C_in = 2
C_out = 4
x = mx.ones((N, L, C_in))
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
c.weight = mx.ones_like(c.weight)
y = c(x)
self.assertEqual(y.shape, [N, L - ks + 1, C_out])
self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
y = c(x)
self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
self.assertTrue("bias" in c.parameters())
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
self.assertTrue("bias" not in c.parameters())
def test_conv2d(self):
x = mx.ones((4, 8, 8, 3))
c = nn.Conv2d(3, 1, 8)
y = c(x)
self.assertEqual(y.shape, [4, 1, 1, 1])
c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
y = c(x)
self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))
# 3x3 conv no padding stride 1
c = nn.Conv2d(3, 8, 3)
y = c(x)
self.assertEqual(y.shape, [4, 6, 6, 8])
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
# 3x3 conv padding 1 stride 1
c = nn.Conv2d(3, 8, 3, padding=1)
y = c(x)
self.assertEqual(y.shape, [4, 8, 8, 8])
self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
self.assertLess(
mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
1e-4,
)
self.assertLess(
mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(),
1e-4,
)
self.assertLess(
mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(),
1e-4,
)
self.assertLess(
mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(),
1e-4,
)
# 3x3 conv no padding stride 2
c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
y = c(x)
self.assertEqual(y.shape, [4, 3, 3, 8])
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
def test_sequential(self):
x = mx.ones((10, 2))
m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
y = m(x)
self.assertEqual(y.shape, [10, 1])
params = m.parameters()
self.assertTrue("layers" in params)
self.assertEqual(len(params["layers"]), 3)
self.assertTrue("weight" in params["layers"][0])
self.assertEqual(len(params["layers"][1]), 0)
self.assertTrue("weight" in params["layers"][2])
m.layers[1] = nn.relu
y2 = m(x)
self.assertTrue(mx.array_equal(y, y2))
def test_module_utilities(self):
m = nn.Sequential(
nn.Sequential(nn.Linear(2, 10), nn.relu),
nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
nn.Linear(10, 1),
mx.sigmoid,
)
children = m.children()
self.assertTrue(isinstance(children, dict))
self.assertEqual(len(children), 1)
self.assertTrue(isinstance(children["layers"], list))
self.assertEqual(len(children["layers"]), 4)
self.assertEqual(children["layers"][3], {})
flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
self.assertEqual(len(flat_children), 3)
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
self.assertEqual(len(leaves), 4)
self.assertEqual(leaves[0][0], "layers.0.layers.0")
self.assertEqual(leaves[1][0], "layers.1.layers.0")
self.assertEqual(leaves[2][0], "layers.1.layers.1")
self.assertEqual(leaves[3][0], "layers.2")
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
self.assertTrue(leaves[3][1] is m.layers[2])
m.eval()
def assert_not_training(k, m):
self.assertFalse(m.training)
m.apply_to_modules(assert_not_training)
m.train()
def assert_training(k, m):
self.assertTrue(m.training)
m.apply_to_modules(assert_training)
def test_sin_pe(self):
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
x = mx.arange(10)
y = m(x)
self.assertEqual(y.shape, [10, 16])
similarities = y @ y.T
self.assertLess(
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
)
def test_io(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
m = make_model()
tdir = tempfile.TemporaryDirectory()
file = os.path.join(tdir.name, "model.npz")
m.save_weights(file)
m_load = make_model()
m_load.load_weights(file)
tdir.cleanup()
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
self.assertTrue(all(tree_flatten(eq_tree)))
if __name__ == "__main__":
unittest.main()

python/tests/test_optimizers.py Normal file

@@ -0,0 +1,29 @@
import unittest
import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
class TestOptimizers(mlx_tests.MLXTestCase):
def test_optimizers(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
update = optim.apply_gradients(grads, params)
mx.eval(update)
equal_shape = mlx.utils.tree_map(
lambda x, y: x.shape == y.shape, params, update
)
all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
self.assertTrue(all_equal)
if __name__ == "__main__":
unittest.main()

python/tests/test_random.py Normal file

@@ -0,0 +1,192 @@
import unittest
import mlx.core as mx
import mlx_tests
class TestRandom(mlx_tests.MLXTestCase):
def test_global_rng(self):
mx.random.seed(3)
a = mx.random.uniform()
b = mx.random.uniform()
mx.random.seed(3)
x = mx.random.uniform()
y = mx.random.uniform()
self.assertEqual(a.item(), x.item())
self.assertEqual(y.item(), b.item())
def test_key(self):
k1 = mx.random.key(0)
k2 = mx.random.key(0)
self.assertTrue(mx.array_equal(k1, k2))
k2 = mx.random.key(1)
self.assertFalse(mx.array_equal(k1, k2))
def test_key_split(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
self.assertFalse(mx.array_equal(k1, k2))
r1, r2 = mx.random.split(key)
self.assertTrue(mx.array_equal(k1, r1))
self.assertTrue(mx.array_equal(k2, r2))
keys = mx.random.split(key, 10)
self.assertEqual(keys.shape, [10, 2])
def test_uniform(self):
key = mx.random.key(0)
a = mx.random.uniform(key=key)
self.assertEqual(a.shape, [])
self.assertEqual(a.dtype, mx.float32)
b = mx.random.uniform(key=key)
self.assertEqual(a.item(), b.item())
a = mx.random.uniform(shape=(2, 3))
self.assertEqual(a.shape, [2, 3])
a = mx.random.uniform(shape=(1000,), low=-1, high=5)
self.assertTrue(mx.all(a > -1).item() and mx.all(a < 5).item())
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
self.assertTrue(mx.all(a > -1).item() and mx.all(a < 5).item())
def test_normal(self):
key = mx.random.key(0)
a = mx.random.normal(key=key)
self.assertEqual(a.shape, [])
self.assertEqual(a.dtype, mx.float32)
b = mx.random.normal(key=key)
self.assertEqual(a.item(), b.item())
a = mx.random.normal(shape=(2, 3))
self.assertEqual(a.shape, [2, 3])
## Generate in float16 or bfloat16
for t in [mx.float16, mx.bfloat16]:
a = mx.random.normal(dtype=t)
self.assertEqual(a.dtype, t)
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, [])
self.assertEqual(a.dtype, mx.int32)
shape = [88]
low = mx.array(3)
high = mx.array(15)
key = mx.random.key(0)
a = mx.random.randint(low, high, shape, key=key)
self.assertEqual(a.shape, shape)
self.assertEqual(a.dtype, mx.int32)
# Check using the same key yields the same value
b = mx.random.randint(low, high, shape, key=key)
self.assertListEqual(a.tolist(), b.tolist())
shape = [3, 4]
low = mx.reshape(mx.array([0] * 3), [3, 1])
high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])
a = mx.random.randint(low, high, shape)
self.assertEqual(a.shape, shape)
a = mx.random.randint(-10, 10, [1000, 1000])
self.assertTrue(mx.all(-10 <= a).item() and mx.all(a < 10).item())
a = mx.random.randint(10, -10, [1000, 1000])
self.assertTrue(mx.all(a == 10).item())
def test_bernoulli(self):
a = mx.random.bernoulli()
self.assertEqual(a.shape, [])
self.assertEqual(a.dtype, mx.bool_)
a = mx.random.bernoulli(mx.array(0.5), [5])
self.assertEqual(a.shape, [5])
a = mx.random.bernoulli(mx.array([2.0, -2.0]))
self.assertEqual(a.tolist(), [True, False])
self.assertEqual(a.shape, [2])
p = mx.array([0.1, 0.2, 0.3])
p = mx.reshape(p, [1, 3])
x = mx.random.bernoulli(p, [4, 3])
self.assertEqual(x.shape, [4, 3])
with self.assertRaises(ValueError):
mx.random.bernoulli(p, [2]) # Bad shape
with self.assertRaises(ValueError):
mx.random.bernoulli(0, [2]) # Bad type
def test_truncated_normal(self):
a = mx.random.truncated_normal(-2.0, 2.0)
self.assertEqual(a.size, 1)
self.assertEqual(a.dtype, mx.float32)
a = mx.random.truncated_normal(mx.array([]), mx.array([]))
self.assertEqual(a.dtype, mx.float32)
self.assertEqual(a.size, 0)
lower = mx.reshape(mx.array([-2.0, 0.0]), [1, 2])
upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
a = mx.random.truncated_normal(lower, upper)
self.assertEqual(a.shape, [3, 2])
self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())
a = mx.random.truncated_normal(2.0, -2.0)
self.assertTrue(mx.all(a == 2.0).item())
a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
self.assertEqual(a.shape, [542, 399])
lower = mx.array([-2.0, -1.0])
higher = mx.array([1.0, 2.0, 3.0])
with self.assertRaises(ValueError):
mx.random.truncated_normal(lower, higher) # Bad shape
def test_gumbel(self):
samples = mx.random.gumbel(shape=(100, 100))
self.assertEqual(samples.shape, [100, 100])
self.assertEqual(samples.dtype, mx.float32)
mean = 0.5772
# Std deviation of the sample mean is small (<0.02),
# so this test is pretty conservative
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
def test_categorical(self):
logits = mx.zeros((10, 20))
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
self.assertEqual(mx.random.categorical(logits, 1).shape, [10])
out = mx.random.categorical(logits)
self.assertEqual(out.shape, [10])
self.assertEqual(out.dtype, mx.uint32)
self.assertTrue(mx.max(out).item() < 20)
out = mx.random.categorical(logits, 0, [5, 20])
self.assertEqual(out.shape, [5, 20])
self.assertTrue(mx.max(out).item() < 10)
out = mx.random.categorical(logits, 1, num_samples=7)
self.assertEqual(out.shape, [10, 7])
out = mx.random.categorical(logits, 0, num_samples=7)
self.assertEqual(out.shape, [20, 7])
with self.assertRaises(ValueError):
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
if __name__ == "__main__":
unittest.main()

python/tests/test_tree.py Normal file

@@ -0,0 +1,26 @@
import unittest
import mlx.core as mx
import mlx.utils
import mlx_tests
class TestTreeUtils(mlx_tests.MLXTestCase):
def test_tree_map(self):
tree = {"a": 0, "b": 1, "c": 2}
tree = mlx.utils.tree_map(lambda x: x + 1, tree)
expected_tree = {"a": 1, "b": 2, "c": 3}
self.assertEqual(tree, expected_tree)
def test_tree_flatten(self):
tree = [{"a": 1, "b": 2}, "c"]
vals = (1, 2, "c")
flat_tree = mlx.utils.tree_flatten(tree)
self.assertEqual(list(zip(*flat_tree))[1], vals)
self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
if __name__ == "__main__":
unittest.main()
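For reference, tree_flatten returns a list of (key, value) pairs with dot-separated key paths (list entries keyed by their index), which is why only the second element of each pair is compared above; tree_unflatten inverts it. A small round-trip sketch (the exact key strings shown are an assumption, consistent with the dotted paths checked in test_nn.py above):

import mlx.utils

tree = [{"a": 1, "b": 2}, "c"]
flat = mlx.utils.tree_flatten(tree)
print(flat)                              # expected: [("0.a", 1), ("0.b", 2), ("1", "c")]
print(mlx.utils.tree_unflatten(flat))    # reconstructs [{"a": 1, "b": 2}, "c"]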

python/tests/test_vmap.py Normal file

@@ -0,0 +1,167 @@
import unittest
import mlx.core as mx
import mlx_tests
class TestVmap(mlx_tests.MLXTestCase):
def test_basics(self):
# Can't vmap over scalars
with self.assertRaises(ValueError):
mx.vmap(mx.exp)(mx.array(1.0))
# Invalid input
with self.assertRaises(ValueError):
mx.vmap(mx.exp)("hello")
# Invalid axes
with self.assertRaises(ValueError):
mx.vmap(mx.exp, in_axes="hello")(mx.array([0, 1]))
with self.assertRaises(ValueError):
mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1]))
with self.assertRaises(ValueError):
mx.vmap(mx.exp, out_axes="hello")(mx.array([0, 1]))
with self.assertRaises(ValueError):
mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1]))
def test_unary(self):
ops = [
"abs",
"cos",
"erf",
"erfinv",
"exp",
"log",
"log1p",
"log2",
"log10",
"logical_not",
"negative",
"reciprocal",
"rsqrt",
"sigmoid",
"sign",
"sin",
"sqrt",
"square",
]
ops = ["erfinv"]
for opname in ops:
with self.subTest(op=opname):
op = getattr(mx, opname)
x = mx.arange(5)
y = mx.vmap(op)(x)
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
x = mx.arange(8).reshape(2, 4)
y = mx.vmap(op)(x)
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
y = mx.vmap(op, in_axes=1, out_axes=1)(x)
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
def test_binary(self):
ops = [
"add",
"divide",
"equal",
"greater",
"greater_equal",
"less",
"less_equal",
"logaddexp",
"maximum",
"minimum",
"multiply",
"power",
"subtract",
]
for opname in ops:
with self.subTest(op=opname):
op = getattr(mx, opname)
x = mx.random.uniform(shape=(5,))
y = mx.random.uniform(shape=(5,))
out = mx.vmap(op)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y)))
x = mx.random.uniform(shape=(2, 4))
y = mx.random.uniform(shape=(2, 4))
out = mx.vmap(op)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y)))
out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y)))
y = mx.random.uniform(shape=(4, 2))
out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y.T)))
out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y.T).T))
def test_tree(self):
def my_fun(tree):
return (tree["a"] + tree["b"][0]) * tree["b"][1]
tree = {
"a": mx.random.uniform(shape=(2, 4)),
"b": (
mx.random.uniform(shape=(2, 4)),
mx.random.uniform(shape=(2, 4)),
),
}
out = mx.vmap(my_fun)(tree)
expected = my_fun(tree)
self.assertTrue(mx.array_equal(out, expected))
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
tree = {
"a": mx.random.uniform(shape=(2, 4)),
"b": (
mx.random.uniform(shape=(4, 2)),
mx.random.uniform(shape=(4, 2)),
),
}
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
self.assertTrue(mx.array_equal(out, expected))
def my_fun(x, y):
return {"a": x + y, "b": x * y}
x = mx.random.uniform(shape=(2, 4))
y = mx.random.uniform(shape=(2, 4))
out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y)
expected = my_fun(x, y)
self.assertTrue(mx.array_equal(out["a"], expected["a"]))
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y)
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes=0, out_axes={"a": 0, "c": 1})(x, y)
out = mx.vmap(my_fun, in_axes=0, out_axes={"a": 1, "b": 0})(x, y)
expected = my_fun(x, y)
self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
if __name__ == "__main__":
unittest.main()