mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
jagrit's commit files
This commit is contained in:
16
python/tests/mlx_tests.py
Normal file
16
python/tests/mlx_tests.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class MLXTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.default = mx.default_device()
|
||||
device = os.getenv("DEVICE", None)
|
||||
if device is not None:
|
||||
device = getattr(mx, device)
|
||||
mx.set_default_device(device)
|
||||
|
||||
def tearDown(self):
|
||||
mx.set_default_device(self.default)
|
445
python/tests/test_blas.py
Normal file
445
python/tests/test_blas.py
Normal file
@@ -0,0 +1,445 @@
|
||||
import unittest
|
||||
from itertools import permutations
|
||||
|
||||
import math
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestBlas(mlx_tests.MLXTestCase):
|
||||
@property
|
||||
def dtypes(self):
|
||||
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
||||
|
||||
def __gemm_test(
|
||||
self,
|
||||
shape_a,
|
||||
shape_b,
|
||||
np_dtype=np.float32,
|
||||
f_np_a=lambda x: x,
|
||||
f_np_b=lambda x: x,
|
||||
f_mx_a=lambda x: x,
|
||||
f_mx_b=lambda x: x,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=np.dtype(np_dtype).name, shape_a=shape_a, shape_b=shape_b
|
||||
):
|
||||
np.random.seed(42)
|
||||
scale = max(np.sum(shape_a), 128)
|
||||
a_np = np.random.normal(0.0, 1.0 / scale, shape_a).astype(np_dtype)
|
||||
b_np = np.random.normal(0.0, 1.0 / scale, shape_b).astype(np_dtype)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_np = f_np_a(a_np.astype(np.float32))
|
||||
b_np = f_np_b(b_np.astype(np.float32))
|
||||
a_mx = f_mx_a(a_mx)
|
||||
b_mx = f_mx_b(b_mx)
|
||||
|
||||
out_npy = a_np @ b_np
|
||||
out_mlx = a_mx @ b_mx
|
||||
|
||||
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
|
||||
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
|
||||
|
||||
def test_matmul_unaligned(self):
|
||||
if not mx.metal.is_available():
|
||||
return
|
||||
|
||||
for dtype in self.dtypes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
base_shapes = [4, 8, 16, 32, 64, 128]
|
||||
pertubations = [-2, -1, 0, 1, 2]
|
||||
|
||||
for dim in base_shapes:
|
||||
for p in pertubations:
|
||||
shape_a = (dim + p, dim + p)
|
||||
shape_b = (dim + p, dim + p)
|
||||
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||
|
||||
def test_matmul_shapes(self):
|
||||
if not mx.metal.is_available():
|
||||
return
|
||||
|
||||
shapes = [
|
||||
(1, 2, 1, 1),
|
||||
(1, 1, 2, 1),
|
||||
(3, 23, 457, 3),
|
||||
]
|
||||
|
||||
if mx.default_device() == mx.gpu:
|
||||
shapes += [
|
||||
(16, 768, 768, 128),
|
||||
]
|
||||
|
||||
for dtype in self.dtypes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
for B, M, N, K in shapes:
|
||||
|
||||
with self.subTest(tranpose="nn"):
|
||||
shape_a = (B, M, K)
|
||||
shape_b = (B, K, N)
|
||||
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||
|
||||
with self.subTest(tranpose="nt"):
|
||||
shape_a = (B, M, K)
|
||||
shape_b = (B, N, K)
|
||||
self.__gemm_test(
|
||||
shape_a,
|
||||
shape_b,
|
||||
np_dtype,
|
||||
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
|
||||
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||
)
|
||||
|
||||
with self.subTest(tranpose="tn"):
|
||||
shape_a = (B, K, M)
|
||||
shape_b = (B, K, N)
|
||||
self.__gemm_test(
|
||||
shape_a,
|
||||
shape_b,
|
||||
np_dtype,
|
||||
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
|
||||
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||
)
|
||||
|
||||
with self.subTest(tranpose="tt"):
|
||||
shape_a = (B, K, M)
|
||||
shape_b = (B, N, K)
|
||||
self.__gemm_test(
|
||||
shape_a,
|
||||
shape_b,
|
||||
np_dtype,
|
||||
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
|
||||
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
|
||||
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||
)
|
||||
|
||||
def test_matmul(self):
|
||||
# Note: so far, matmul only works with floating-point types
|
||||
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
b = mx.array([[0.0, -1.0], [-3.0, 3.0]])
|
||||
|
||||
expected = [[-6.0, 5.0], [-12.0, 9.0]]
|
||||
|
||||
self.assertEqual((a @ b).tolist(), expected)
|
||||
self.assertEqual(mx.matmul(a, b).tolist(), expected)
|
||||
|
||||
# Transposed matmul
|
||||
np.random.seed(0)
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
|
||||
d_npy = np.transpose(a_npy, (1, 0)) @ b_npy
|
||||
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
|
||||
d_mlx = mx.transpose(a_mlx, (1, 0)) @ b_mlx
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
|
||||
|
||||
def test_matmul_dtypes(self):
|
||||
|
||||
for dt in self.dtypes:
|
||||
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
||||
getattr(np, dt)
|
||||
)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
||||
getattr(np, dt)
|
||||
)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
c_npy = np.matmul(a_npy, b_npy, dtype=getattr(np, dt))
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
def test_matmul_batched(self):
|
||||
np.random.seed(0)
|
||||
# Batched matmul
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
|
||||
c_npy = a_npy @ b_npy
|
||||
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Batched and transposed matmul
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
c_npy = a_npy @ np.transpose(b_npy, (0, 2, 1))
|
||||
|
||||
b_mlx = mx.array(b_npy)
|
||||
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 2, 1))
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Batched matmul with simple broadast
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
||||
c_npy = a_npy @ b_npy
|
||||
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Both operands broadcasted
|
||||
d_npy = np.broadcast_to(b_npy, (5, 16, 16))
|
||||
d_mlx = mx.broadcast_to(b_mlx, (5, 16, 16))
|
||||
|
||||
e_npy = d_npy @ d_npy
|
||||
e_mlx = d_mlx @ d_mlx
|
||||
|
||||
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
|
||||
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
|
||||
|
||||
# Batched and transposed matmul with simple broadast
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
|
||||
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Matmul with vector
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
c_npy = a_npy @ b_npy
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
# Test Multiheaded attention style matmul
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
a_npy = np.transpose(a_npy, (0, 2, 1, 3))
|
||||
b_npy = np.transpose(b_npy, (0, 2, 1, 3))
|
||||
a_mlx = mx.transpose(a_mlx, (0, 2, 1, 3))
|
||||
b_mlx = mx.transpose(b_mlx, (0, 2, 1, 3))
|
||||
|
||||
c_npy = a_npy @ np.transpose(b_npy, (0, 1, 3, 2))
|
||||
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 1, 3, 2))
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||
|
||||
def __gemv_test(
|
||||
self,
|
||||
shape_mat,
|
||||
shape_vec,
|
||||
np_dtype=np.float32,
|
||||
mat_first=True,
|
||||
np_mat_f=lambda x: x,
|
||||
np_vec_f=lambda x: x,
|
||||
mlx_mat_f=lambda x: x,
|
||||
mlx_vec_f=lambda x: x,
|
||||
):
|
||||
with self.subTest(shape=shape_mat):
|
||||
np.random.seed(42)
|
||||
scale = max(np.sum(shape_mat), 32)
|
||||
mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
|
||||
vec_npy = np.random.normal(0.0, 1.0 / scale, shape_vec).astype(np_dtype)
|
||||
|
||||
mat_mlx = mx.array(mat_npy)
|
||||
vec_mlx = mx.array(vec_npy)
|
||||
|
||||
mat_npy = np_mat_f(mat_npy)
|
||||
vec_npy = np_vec_f(vec_npy)
|
||||
mat_mlx = mlx_mat_f(mat_mlx)
|
||||
vec_mlx = mlx_vec_f(vec_mlx)
|
||||
|
||||
if mat_first:
|
||||
out_npy = mat_npy @ vec_npy
|
||||
out_mlx = mat_mlx @ vec_mlx
|
||||
else:
|
||||
out_npy = vec_npy @ mat_npy
|
||||
out_mlx = vec_mlx @ mat_mlx
|
||||
|
||||
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
|
||||
self.assertTrue(np.allclose(out_mlx, out_npy, atol=1e-5))
|
||||
|
||||
def test_matrix_vector(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
# Basic square matrix test
|
||||
self.__gemv_test(
|
||||
shape_mat=(64, 64), shape_vec=(64, 1), np_dtype=np_dtype
|
||||
)
|
||||
self.__gemv_test(
|
||||
shape_mat=(64, 64),
|
||||
shape_vec=(64, 1),
|
||||
np_dtype=np_dtype,
|
||||
mat_first=False,
|
||||
np_vec_f=lambda x: np.transpose(x, (1, 0)),
|
||||
mlx_vec_f=lambda x: mx.transpose(x, (1, 0)),
|
||||
)
|
||||
|
||||
# Vector matrix product with aligned and unaligned shapes
|
||||
for in_len_base, out_len_base in (
|
||||
(2, 2),
|
||||
(32, 32),
|
||||
(64, 64),
|
||||
(2048, 2048),
|
||||
):
|
||||
for mi in (-1, 0, 1):
|
||||
for mj in (-1, 0, 1):
|
||||
# Vec mat
|
||||
shape_mat = (in_len_base + mi, out_len_base + mj)
|
||||
shape_vec = (1, in_len_base + mi)
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
|
||||
)
|
||||
|
||||
# Mat vec
|
||||
shape_mat = (out_len_base + mj, in_len_base + mi)
|
||||
shape_vec = (in_len_base + mi, 1)
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
|
||||
)
|
||||
|
||||
def test_matrix_vector_batched(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
# Batched mat vec
|
||||
for shape_mat, shape_vec in (
|
||||
((32, 128, 64), (32, 64, 1)),
|
||||
((128, 64), (32, 64, 1)),
|
||||
((32, 128, 64), (64, 1)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
|
||||
)
|
||||
|
||||
# Batched vec mat
|
||||
for shape_vec, shape_mat in (
|
||||
((32, 1, 128), (32, 128, 64)),
|
||||
((32, 1, 128), (128, 64)),
|
||||
((1, 128), (32, 128, 64)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
|
||||
)
|
||||
|
||||
def test_matrix_vector_broadcast(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
# Different broadcasts mat vec
|
||||
for shape_mat, shape_vec in (
|
||||
((32, 64, 64), (32, 64, 1)),
|
||||
((64, 64), (32, 64, 1)),
|
||||
((32, 64, 64), (64, 1)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat=(64, 64),
|
||||
shape_vec=(64, 1),
|
||||
np_dtype=np_dtype,
|
||||
np_mat_f=(lambda mat_npy: np.broadcast_to(mat_npy, shape_mat)),
|
||||
np_vec_f=(lambda vec_npy: np.broadcast_to(vec_npy, shape_vec)),
|
||||
mlx_mat_f=(lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat)),
|
||||
mlx_vec_f=(lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec)),
|
||||
)
|
||||
|
||||
# Different broadcasts vec mat
|
||||
for shape_vec, shape_mat in (
|
||||
((32, 1, 64), (32, 64, 64)),
|
||||
((32, 1, 64), (64, 64)),
|
||||
((1, 64), (32, 64, 64)),
|
||||
):
|
||||
self.__gemv_test(
|
||||
shape_mat=(64, 64),
|
||||
shape_vec=(1, 64),
|
||||
np_dtype=np_dtype,
|
||||
mat_first=False,
|
||||
np_mat_f=lambda mat_npy: np.broadcast_to(mat_npy, shape_mat),
|
||||
np_vec_f=lambda vec_npy: np.broadcast_to(vec_npy, shape_vec),
|
||||
mlx_mat_f=lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat),
|
||||
mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
|
||||
)
|
||||
|
||||
def test_matrix_vector_edgecases(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
for in_vec_len in np.arange(1, 5):
|
||||
for out_vec_len in np.arange(1, 5):
|
||||
for batch_size in np.arange(1, 5):
|
||||
with self.subTest(
|
||||
problem_shape=(batch_size, in_vec_len, out_vec_len)
|
||||
):
|
||||
# Matrix vector
|
||||
with self.subTest(transpose=False):
|
||||
a_npy = np.ones(
|
||||
(batch_size, out_vec_len, in_vec_len),
|
||||
dtype=np_dtype,
|
||||
)
|
||||
b_npy = np.ones(
|
||||
(batch_size, in_vec_len, 1), dtype=np_dtype
|
||||
)
|
||||
for i in range(batch_size):
|
||||
b_npy[i] *= i + 1.0
|
||||
|
||||
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
|
||||
c_npy = a_npy @ b_npy
|
||||
c_mlx = a_mlx @ b_mlx
|
||||
|
||||
self.assertListEqual(
|
||||
list(c_npy.shape), list(c_mlx.shape)
|
||||
)
|
||||
self.assertTrue(np.array_equal(c_mlx, c_npy))
|
||||
|
||||
# Vector matrix
|
||||
with self.subTest(transpose=True):
|
||||
a_npy = np.ones(
|
||||
(batch_size, out_vec_len, in_vec_len),
|
||||
dtype=np_dtype,
|
||||
)
|
||||
b_npy = np.ones(
|
||||
(batch_size, 1, out_vec_len), dtype=np_dtype
|
||||
)
|
||||
for i in range(batch_size):
|
||||
b_npy[i] *= i + 1.0
|
||||
|
||||
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
|
||||
c_npy = b_npy @ a_npy
|
||||
c_mlx = b_mlx @ a_mlx
|
||||
|
||||
self.assertListEqual(
|
||||
list(c_npy.shape), list(c_mlx.shape)
|
||||
)
|
||||
self.assertTrue(np.array_equal(c_mlx, c_npy))
|
445
python/tests/test_conv.py
Normal file
445
python/tests/test_conv.py
Normal file
@@ -0,0 +1,445 @@
|
||||
import unittest
|
||||
from itertools import permutations
|
||||
|
||||
import math
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
import mlx_tests
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
has_torch = True
|
||||
except ImportError as e:
|
||||
has_torch = False
|
||||
|
||||
|
||||
class TestConv(mlx_tests.MLXTestCase):
|
||||
def test_numpy_conv(self):
|
||||
for dtype in (
|
||||
"float16",
|
||||
"float32",
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
for M, N, mode in (
|
||||
(1, 1, "full"),
|
||||
(25, 5, "full"),
|
||||
(24, 5, "same"),
|
||||
(24, 4, "same"),
|
||||
(24, 4, "valid"),
|
||||
(4, 24, "full"),
|
||||
(5, 25, "same"),
|
||||
(4, 25, "valid"),
|
||||
):
|
||||
with self.subTest(dtype=dtype, M=M, N=N, mode=mode):
|
||||
atol = 1e-6 if dtype == "float32" else 1e-5
|
||||
a_np = np.random.rand(M).astype(np_dtype)
|
||||
v_np = np.random.rand(N).astype(np_dtype)
|
||||
a_mx = mx.array(a_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
c_np = np.convolve(a_np, v_np, mode=mode)
|
||||
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
|
||||
|
||||
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
|
||||
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_1D(self):
|
||||
def run_conv1D(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
iH,
|
||||
kH,
|
||||
stride,
|
||||
padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
iH=iH,
|
||||
kH=kH,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt, wt_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 2, 1)), (in_np, wt_np)
|
||||
)
|
||||
|
||||
out_mx = mx.conv1d(
|
||||
in_mx,
|
||||
wt_mx,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.conv1d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for iH, kH, stride, padding in (
|
||||
(1, 1, 1, 0),
|
||||
(3, 3, 1, 0),
|
||||
(31, 5, 5, 2),
|
||||
):
|
||||
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||
|
||||
# Strided inputs tests
|
||||
for tpose_in, tpose_wt in (
|
||||
((0, 2, 1), (0, 1, 2)),
|
||||
((0, 2, 1), (0, 2, 1)),
|
||||
):
|
||||
with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
|
||||
in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
|
||||
wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_mx_t = mx.transpose(in_mx, tpose_in)
|
||||
wt_mx_t = mx.transpose(wt_mx, tpose_wt)
|
||||
out_mx = mx.conv1d(in_mx_t, wt_mx_t)
|
||||
|
||||
in_pt, wt_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
|
||||
(in_np.transpose(tpose_in), wt_np.transpose(tpose_wt)),
|
||||
)
|
||||
|
||||
out_pt = torch.conv1d(in_pt, wt_pt)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_1D_grad(self):
|
||||
def run_conv1D_grad(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
iH,
|
||||
kH,
|
||||
stride,
|
||||
padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
iH=iH,
|
||||
kH=kH,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)
|
||||
|
||||
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||
ct_np = np.random.normal(0, 1.0 / C, (N, oH, O)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
|
||||
in_pt, wt_pt, ct_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
|
||||
(in_np, wt_np, ct_np),
|
||||
)
|
||||
|
||||
def f(a, b):
|
||||
return mx.conv1d(
|
||||
a,
|
||||
b,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
_, outs_mx = mx.vjp(
|
||||
f,
|
||||
[
|
||||
in_mx,
|
||||
wt_mx,
|
||||
],
|
||||
[
|
||||
ct_mx,
|
||||
],
|
||||
)
|
||||
pt_grad_in = F.grad.conv1d_input(
|
||||
in_pt.shape,
|
||||
wt_pt,
|
||||
ct_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_wt = F.grad.conv1d_weight(
|
||||
in_pt,
|
||||
wt_pt.shape,
|
||||
ct_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_in = torch.transpose(pt_grad_in, 2, 1).numpy()
|
||||
pt_grad_wt = torch.transpose(pt_grad_wt, 2, 1).numpy()
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for iH, kH, stride, padding in (
|
||||
(1, 1, 1, 0),
|
||||
(3, 3, 1, 0),
|
||||
(31, 5, 5, 2),
|
||||
):
|
||||
run_conv1D_grad(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_2D(self):
|
||||
def run_conv2D(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
idim,
|
||||
kdim,
|
||||
stride,
|
||||
padding,
|
||||
dilation=(1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
idim=idim,
|
||||
kdim=kdim,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iH, iW = idim
|
||||
kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kH * kW * C)
|
||||
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt, wt_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
|
||||
(in_np, wt_np),
|
||||
)
|
||||
|
||||
out_mx = mx.conv2d(
|
||||
in_mx,
|
||||
wt_mx,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.conv2d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||
):
|
||||
run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_2D_grad(self):
|
||||
def run_conv2D_grad(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
idim,
|
||||
kdim,
|
||||
stride,
|
||||
padding,
|
||||
dilation=(1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
idim=idim,
|
||||
kdim=kdim,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iH, iW = idim
|
||||
kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kH * kW * C)
|
||||
|
||||
oH = 1 + (
|
||||
(iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0]
|
||||
)
|
||||
oW = 1 + (
|
||||
(iW + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1]
|
||||
)
|
||||
|
||||
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)
|
||||
ct_np = np.random.normal(0.0, scale, (N, oH, oW, O)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
|
||||
in_pt, wt_pt, ct_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
|
||||
(in_np, wt_np, ct_np),
|
||||
)
|
||||
|
||||
def f(a, b):
|
||||
return mx.conv2d(
|
||||
a,
|
||||
b,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
_, outs_mx = mx.vjp(
|
||||
f,
|
||||
[
|
||||
in_mx,
|
||||
wt_mx,
|
||||
],
|
||||
[
|
||||
ct_mx,
|
||||
],
|
||||
)
|
||||
pt_grad_in = F.grad.conv1d_input(
|
||||
in_pt.shape,
|
||||
wt_pt,
|
||||
ct_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_wt = F.grad.conv1d_weight(
|
||||
in_pt,
|
||||
wt_pt.shape,
|
||||
ct_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 1)).numpy()
|
||||
pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 1)).numpy()
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||
):
|
||||
run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
157
python/tests/test_load.py
Normal file
157
python/tests/test_load.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import unittest
|
||||
import os
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import tempfile
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestLoad(mlx_tests.MLXTestCase):
|
||||
dtypes = [
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"float32",
|
||||
"float16",
|
||||
"complex64",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_save_and_load(self):
|
||||
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
for dt in self.dtypes:
|
||||
with self.subTest(dtype=dt):
|
||||
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||
with self.subTest(shape=shape):
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
|
||||
save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")
|
||||
|
||||
save_arr = np.random.uniform(0.0, 32.0, size=shape)
|
||||
save_arr_npy = save_arr.astype(getattr(np, dt))
|
||||
save_arr_mlx = mx.array(save_arr_npy)
|
||||
|
||||
mx.save(save_file_mlx, save_arr_mlx)
|
||||
np.save(save_file_npy, save_arr_npy)
|
||||
|
||||
# Load array saved by mlx as mlx array
|
||||
load_arr_mlx_mlx = mx.load(save_file_mlx)
|
||||
self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))
|
||||
|
||||
# Load array saved by numpy as mlx array
|
||||
load_arr_npy_mlx = mx.load(save_file_npy)
|
||||
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))
|
||||
|
||||
# Load array saved by mlx as numpy array
|
||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||
|
||||
def test_save_and_load_fs(self):
|
||||
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
for dt in self.dtypes:
|
||||
with self.subTest(dtype=dt):
|
||||
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||
with self.subTest(shape=shape):
|
||||
save_file_mlx = os.path.join(
|
||||
self.test_dir, f"mlx_{dt}_{i}_fs.npy"
|
||||
)
|
||||
save_file_npy = os.path.join(
|
||||
self.test_dir, f"npy_{dt}_{i}_fs.npy"
|
||||
)
|
||||
|
||||
save_arr = np.random.uniform(0.0, 32.0, size=shape)
|
||||
save_arr_npy = save_arr.astype(getattr(np, dt))
|
||||
save_arr_mlx = mx.array(save_arr_npy)
|
||||
|
||||
with open(save_file_mlx, "wb") as f:
|
||||
mx.save(f, save_arr_mlx)
|
||||
|
||||
np.save(save_file_npy, save_arr_npy)
|
||||
|
||||
# Load array saved by mlx as mlx array
|
||||
with open(save_file_mlx, "rb") as f:
|
||||
load_arr_mlx_mlx = mx.load(f)
|
||||
self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))
|
||||
|
||||
# Load array saved by numpy as mlx array
|
||||
with open(save_file_npy, "rb") as f:
|
||||
load_arr_npy_mlx = mx.load(f)
|
||||
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))
|
||||
|
||||
# Load array saved by mlx as numpy array
|
||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||
|
||||
def test_savez_and_loadz(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
for dt in self.dtypes:
|
||||
with self.subTest(dtype=dt):
|
||||
shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
|
||||
save_file_mlx_uncomp = os.path.join(
|
||||
self.test_dir, f"mlx_{dt}_uncomp.npz"
|
||||
)
|
||||
save_file_npy_uncomp = os.path.join(
|
||||
self.test_dir, f"npy_{dt}_uncomp.npz"
|
||||
)
|
||||
save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
|
||||
save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")
|
||||
|
||||
# Make dictionary of multiple
|
||||
save_arrs_npy = {
|
||||
f"save_arr_{i}": np.random.uniform(
|
||||
0.0, 32.0, size=shapes[i]
|
||||
).astype(getattr(np, dt))
|
||||
for i in range(len(shapes))
|
||||
}
|
||||
save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}
|
||||
|
||||
# Save as npz files
|
||||
np.savez(save_file_npy_uncomp, **save_arrs_npy)
|
||||
mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
|
||||
np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
|
||||
mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)
|
||||
|
||||
for save_file_npy, save_file_mlx in (
|
||||
(save_file_npy_uncomp, save_file_mlx_uncomp),
|
||||
(save_file_npy_comp, save_file_mlx_comp),
|
||||
):
|
||||
|
||||
# Load array saved by mlx as mlx array
|
||||
load_arr_mlx_mlx = mx.load(save_file_mlx)
|
||||
for k, v in load_arr_mlx_mlx.items():
|
||||
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))
|
||||
|
||||
# Load arrays saved by numpy as mlx arrays
|
||||
load_arr_npy_mlx = mx.load(save_file_npy)
|
||||
for k, v in load_arr_npy_mlx.items():
|
||||
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))
|
||||
|
||||
# Load array saved by mlx as numpy array
|
||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||
for k, v in load_arr_mlx_npy.items():
|
||||
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
231
python/tests/test_nn.py
Normal file
231
python/tests/test_nn.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
import numpy as np
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestNN(mlx_tests.MLXTestCase):
|
||||
def test_linear(self):
|
||||
inputs = mx.zeros((10, 4))
|
||||
layer = nn.Linear(input_dims=4, output_dims=8)
|
||||
outputs = layer(inputs)
|
||||
self.assertEqual(tuple(outputs.shape), (10, 8))
|
||||
|
||||
def test_cross_entropy(self):
|
||||
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
|
||||
targets = mx.array([0, 1])
|
||||
losses = nn.losses.cross_entropy(logits, targets)
|
||||
self.assertTrue(mx.array_equal(losses, mx.zeros((2,))))
|
||||
|
||||
def test_gelu(self):
|
||||
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
|
||||
|
||||
# From: jax.nn.gelu(np.array(inputs), approximate=False)
|
||||
expected = np.array(
|
||||
[1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
|
||||
)
|
||||
|
||||
out = nn.GELU()(mx.array(inputs))
|
||||
self.assertTrue(np.allclose(out, expected))
|
||||
|
||||
# Crudely check the approximations
|
||||
x = mx.arange(-6.0, 6.0, 12 / 100)
|
||||
y = nn.gelu(x)
|
||||
y_hat1 = nn.gelu_approx(x)
|
||||
y_hat2 = nn.gelu_fast_approx(x)
|
||||
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
|
||||
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
|
||||
|
||||
def test_group_norm(self):
|
||||
x = mx.arange(100, dtype=mx.float32)
|
||||
x = x.reshape(1, 10, 10, 1)
|
||||
x = mx.broadcast_to(x, (2, 10, 10, 4))
|
||||
x = mx.concatenate([x, 0.5 * x], axis=-1)
|
||||
|
||||
# Group norm in groups last mode
|
||||
g = nn.GroupNorm(2, 8)
|
||||
y = g(x)
|
||||
means = y.reshape(2, -1, 2).mean(axis=1)
|
||||
var = y.reshape(2, -1, 2).var(axis=1)
|
||||
self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
|
||||
self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
|
||||
g.weight = g.weight * 2
|
||||
g.bias = g.bias + 3
|
||||
y = g(x)
|
||||
means = y.reshape(2, -1, 2).mean(axis=1)
|
||||
var = y.reshape(2, -1, 2).var(axis=1)
|
||||
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
|
||||
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
|
||||
|
||||
# Group norm in groups first mode
|
||||
g = nn.GroupNorm(2, 8, pytorch_compatible=True)
|
||||
y = g(x)
|
||||
means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
|
||||
var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
|
||||
self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
|
||||
self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
|
||||
g.weight = g.weight * 2
|
||||
g.bias = g.bias + 3
|
||||
y = g(x)
|
||||
means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
|
||||
var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
|
||||
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
|
||||
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
|
||||
|
||||
def test_conv1d(self):
|
||||
N = 5
|
||||
L = 12
|
||||
ks = 3
|
||||
C_in = 2
|
||||
C_out = 4
|
||||
x = mx.ones((N, L, C_in))
|
||||
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
|
||||
c.weight = mx.ones_like(c.weight)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [N, L - ks + 1, C_out])
|
||||
self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))
|
||||
|
||||
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
|
||||
self.assertTrue("bias" in c.parameters())
|
||||
|
||||
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
|
||||
self.assertTrue("bias" not in c.parameters())
|
||||
|
||||
def test_conv2d(self):
|
||||
x = mx.ones((4, 8, 8, 3))
|
||||
c = nn.Conv2d(3, 1, 8)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [4, 1, 1, 1])
|
||||
c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
|
||||
y = c(x)
|
||||
self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))
|
||||
|
||||
# 3x3 conv no padding stride 1
|
||||
c = nn.Conv2d(3, 8, 3)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [4, 6, 6, 8])
|
||||
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||
|
||||
# 3x3 conv padding 1 stride 1
|
||||
c = nn.Conv2d(3, 8, 3, padding=1)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [4, 8, 8, 8])
|
||||
self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||
self.assertLess(
|
||||
mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
|
||||
1e-4,
|
||||
)
|
||||
self.assertLess(
|
||||
mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(),
|
||||
1e-4,
|
||||
)
|
||||
self.assertLess(
|
||||
mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(),
|
||||
1e-4,
|
||||
)
|
||||
self.assertLess(
|
||||
mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(),
|
||||
1e-4,
|
||||
)
|
||||
|
||||
# 3x3 conv no padding stride 2
|
||||
c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [4, 3, 3, 8])
|
||||
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||
|
||||
def test_sequential(self):
|
||||
x = mx.ones((10, 2))
|
||||
m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
|
||||
y = m(x)
|
||||
self.assertEqual(y.shape, [10, 1])
|
||||
params = m.parameters()
|
||||
self.assertTrue("layers" in params)
|
||||
self.assertEqual(len(params["layers"]), 3)
|
||||
self.assertTrue("weight" in params["layers"][0])
|
||||
self.assertEqual(len(params["layers"][1]), 0)
|
||||
self.assertTrue("weight" in params["layers"][2])
|
||||
|
||||
m.layers[1] = nn.relu
|
||||
y2 = m(x)
|
||||
self.assertTrue(mx.array_equal(y, y2))
|
||||
|
||||
def test_module_utilities(self):
|
||||
m = nn.Sequential(
|
||||
nn.Sequential(nn.Linear(2, 10), nn.relu),
|
||||
nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
|
||||
nn.Linear(10, 1),
|
||||
mx.sigmoid,
|
||||
)
|
||||
|
||||
children = m.children()
|
||||
self.assertTrue(isinstance(children, dict))
|
||||
self.assertEqual(len(children), 1)
|
||||
self.assertTrue(isinstance(children["layers"], list))
|
||||
self.assertEqual(len(children["layers"]), 4)
|
||||
self.assertEqual(children["layers"][3], {})
|
||||
flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
|
||||
self.assertEqual(len(flat_children), 3)
|
||||
|
||||
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
|
||||
self.assertEqual(len(leaves), 4)
|
||||
self.assertEqual(leaves[0][0], "layers.0.layers.0")
|
||||
self.assertEqual(leaves[1][0], "layers.1.layers.0")
|
||||
self.assertEqual(leaves[2][0], "layers.1.layers.1")
|
||||
self.assertEqual(leaves[3][0], "layers.2")
|
||||
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
|
||||
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
|
||||
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
|
||||
self.assertTrue(leaves[3][1] is m.layers[2])
|
||||
|
||||
m.eval()
|
||||
|
||||
def assert_not_training(k, m):
|
||||
self.assertFalse(m.training)
|
||||
|
||||
m.apply_to_modules(assert_not_training)
|
||||
|
||||
m.train()
|
||||
|
||||
def assert_training(k, m):
|
||||
self.assertTrue(m.training)
|
||||
|
||||
m.apply_to_modules(assert_training)
|
||||
|
||||
def test_sin_pe(self):
|
||||
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
|
||||
x = mx.arange(10)
|
||||
y = m(x)
|
||||
|
||||
self.assertEqual(y.shape, [10, 16])
|
||||
similarities = y @ y.T
|
||||
self.assertLess(
|
||||
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
|
||||
)
|
||||
|
||||
def test_io(self):
|
||||
def make_model():
|
||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||
|
||||
m = make_model()
|
||||
tdir = tempfile.TemporaryDirectory()
|
||||
file = os.path.join(tdir.name, "model.npz")
|
||||
m.save_weights(file)
|
||||
m_load = make_model()
|
||||
m_load.load_weights(file)
|
||||
tdir.cleanup()
|
||||
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
29
python/tests/test_optimizers.py
Normal file
29
python/tests/test_optimizers.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.optimizers as opt
|
||||
import mlx.utils
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
def test_optimizers(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
"second": mx.zeros((1,)),
|
||||
}
|
||||
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
|
||||
update = optim.apply_gradients(grads, params)
|
||||
mx.eval(update)
|
||||
equal_shape = mlx.utils.tree_map(
|
||||
lambda x, y: x.shape == y.shape, params, update
|
||||
)
|
||||
all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
|
||||
self.assertTrue(all_equal)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
192
python/tests/test_random.py
Normal file
192
python/tests/test_random.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestRandom(mlx_tests.MLXTestCase):
|
||||
def test_global_rng(self):
|
||||
mx.random.seed(3)
|
||||
a = mx.random.uniform()
|
||||
b = mx.random.uniform()
|
||||
|
||||
mx.random.seed(3)
|
||||
x = mx.random.uniform()
|
||||
y = mx.random.uniform()
|
||||
|
||||
self.assertEqual(a.item(), x.item())
|
||||
self.assertEqual(y.item(), b.item())
|
||||
|
||||
def test_key(self):
|
||||
k1 = mx.random.key(0)
|
||||
k2 = mx.random.key(0)
|
||||
self.assertTrue(mx.array_equal(k1, k2))
|
||||
|
||||
k2 = mx.random.key(1)
|
||||
self.assertFalse(mx.array_equal(k1, k2))
|
||||
|
||||
def test_key_split(self):
|
||||
key = mx.random.key(0)
|
||||
|
||||
k1, k2 = mx.random.split(key)
|
||||
self.assertFalse(mx.array_equal(k1, k2))
|
||||
|
||||
r1, r2 = mx.random.split(key)
|
||||
self.assertTrue(mx.array_equal(k1, r1))
|
||||
self.assertTrue(mx.array_equal(k2, r2))
|
||||
|
||||
keys = mx.random.split(key, 10)
|
||||
self.assertEqual(keys.shape, [10, 2])
|
||||
|
||||
def test_uniform(self):
|
||||
key = mx.random.key(0)
|
||||
a = mx.random.uniform(key=key)
|
||||
self.assertEqual(a.shape, [])
|
||||
self.assertEqual(a.dtype, mx.float32)
|
||||
|
||||
b = mx.random.uniform(key=key)
|
||||
self.assertEqual(a.item(), b.item())
|
||||
|
||||
a = mx.random.uniform(shape=(2, 3))
|
||||
self.assertEqual(a.shape, [2, 3])
|
||||
|
||||
a = mx.random.uniform(shape=(1000,), low=-1, high=5)
|
||||
self.assertTrue(mx.all((a > -1) < 5).item())
|
||||
|
||||
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
|
||||
self.assertTrue(mx.all((a > -1) < 5).item())
|
||||
|
||||
def test_normal(self):
|
||||
key = mx.random.key(0)
|
||||
a = mx.random.normal(key=key)
|
||||
self.assertEqual(a.shape, [])
|
||||
self.assertEqual(a.dtype, mx.float32)
|
||||
|
||||
b = mx.random.normal(key=key)
|
||||
self.assertEqual(a.item(), b.item())
|
||||
|
||||
a = mx.random.normal(shape=(2, 3))
|
||||
self.assertEqual(a.shape, [2, 3])
|
||||
|
||||
## Generate in float16 or bfloat16
|
||||
for t in [mx.float16, mx.bfloat16]:
|
||||
a = mx.random.normal(dtype=t)
|
||||
self.assertEqual(a.dtype, t)
|
||||
|
||||
def test_randint(self):
|
||||
a = mx.random.randint(0, 1, [])
|
||||
self.assertEqual(a.shape, [])
|
||||
self.assertEqual(a.dtype, mx.int32)
|
||||
|
||||
shape = [88]
|
||||
low = mx.array(3)
|
||||
high = mx.array(15)
|
||||
|
||||
key = mx.random.key(0)
|
||||
a = mx.random.randint(low, high, shape, key=key)
|
||||
self.assertEqual(a.shape, shape)
|
||||
self.assertEqual(a.dtype, mx.int32)
|
||||
|
||||
# Check using the same key yields the same value
|
||||
b = mx.random.randint(low, high, shape, key=key)
|
||||
self.assertListEqual(a.tolist(), b.tolist())
|
||||
|
||||
shape = [3, 4]
|
||||
low = mx.reshape(mx.array([0] * 3), [3, 1])
|
||||
high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])
|
||||
|
||||
a = mx.random.randint(low, high, shape)
|
||||
self.assertEqual(a.shape, shape)
|
||||
|
||||
a = mx.random.randint(-10, 10, [1000, 1000])
|
||||
self.assertTrue(mx.all(-10 <= a).item() and mx.all(a < 10).item())
|
||||
|
||||
a = mx.random.randint(10, -10, [1000, 1000])
|
||||
self.assertTrue(mx.all(a == 10).item())
|
||||
|
||||
def test_bernoulli(self):
|
||||
a = mx.random.bernoulli()
|
||||
self.assertEqual(a.shape, [])
|
||||
self.assertEqual(a.dtype, mx.bool_)
|
||||
|
||||
a = mx.random.bernoulli(mx.array(0.5), [5])
|
||||
self.assertEqual(a.shape, [5])
|
||||
|
||||
a = mx.random.bernoulli(mx.array([2.0, -2.0]))
|
||||
self.assertEqual(a.tolist(), [True, False])
|
||||
self.assertEqual(a.shape, [2])
|
||||
|
||||
p = mx.array([0.1, 0.2, 0.3])
|
||||
mx.reshape(p, [1, 3])
|
||||
x = mx.random.bernoulli(p, [4, 3])
|
||||
self.assertEqual(x.shape, [4, 3])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.bernoulli(p, [2]) # Bad shape
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.bernoulli(0, [2]) # Bad type
|
||||
|
||||
def test_truncated_normal(self):
|
||||
a = mx.random.truncated_normal(-2.0, 2.0)
|
||||
self.assertEqual(a.size, 1)
|
||||
self.assertEqual(a.dtype, mx.float32)
|
||||
|
||||
a = mx.random.truncated_normal(mx.array([]), mx.array([]))
|
||||
self.assertEqual(a.dtype, mx.float32)
|
||||
self.assertEqual(a.size, 0)
|
||||
|
||||
lower = mx.reshape(mx.array([-2.0, 0.0]), [1, 2])
|
||||
upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
|
||||
a = mx.random.truncated_normal(lower, upper)
|
||||
|
||||
self.assertEqual(a.shape, [3, 2])
|
||||
self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())
|
||||
|
||||
a = mx.random.truncated_normal(2.0, -2.0)
|
||||
self.assertTrue(mx.all(a == 2.0).item())
|
||||
|
||||
a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
|
||||
self.assertEqual(a.shape, [542, 399])
|
||||
|
||||
lower = mx.array([-2.0, -1.0])
|
||||
higher = mx.array([1.0, 2.0, 3.0])
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.truncated_normal(lower, higher) # Bad shape
|
||||
|
||||
def test_gumbel(self):
|
||||
samples = mx.random.gumbel(shape=(100, 100))
|
||||
self.assertEqual(samples.shape, [100, 100])
|
||||
self.assertEqual(samples.dtype, mx.float32)
|
||||
mean = 0.5772
|
||||
# Std deviation of the sample mean is small (<0.02),
|
||||
# so this test is pretty conservative
|
||||
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
|
||||
|
||||
def test_categorical(self):
|
||||
logits = mx.zeros((10, 20))
|
||||
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
|
||||
self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
|
||||
self.assertEqual(mx.random.categorical(logits, 1).shape, [10])
|
||||
|
||||
out = mx.random.categorical(logits)
|
||||
self.assertEqual(out.shape, [10])
|
||||
self.assertEqual(out.dtype, mx.uint32)
|
||||
self.assertTrue(mx.max(out).item() < 20)
|
||||
|
||||
out = mx.random.categorical(logits, 0, [5, 20])
|
||||
self.assertEqual(out.shape, [5, 20])
|
||||
self.assertTrue(mx.max(out).item() < 10)
|
||||
|
||||
out = mx.random.categorical(logits, 1, num_samples=7)
|
||||
self.assertEqual(out.shape, [10, 7])
|
||||
out = mx.random.categorical(logits, 0, num_samples=7)
|
||||
self.assertEqual(out.shape, [20, 7])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
26
python/tests/test_tree.py
Normal file
26
python/tests/test_tree.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.utils
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestTreeUtils(mlx_tests.MLXTestCase):
|
||||
def test_tree_map(self):
|
||||
tree = {"a": 0, "b": 1, "c": 2}
|
||||
tree = mlx.utils.tree_map(lambda x: x + 1, tree)
|
||||
|
||||
expected_tree = {"a": 1, "b": 2, "c": 3}
|
||||
self.assertEqual(tree, expected_tree)
|
||||
|
||||
def test_tree_flatten(self):
|
||||
tree = [{"a": 1, "b": 2}, "c"]
|
||||
vals = (1, 2, "c")
|
||||
flat_tree = mlx.utils.tree_flatten(tree)
|
||||
self.assertEqual(list(zip(*flat_tree))[1], vals)
|
||||
self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
167
python/tests/test_vmap.py
Normal file
167
python/tests/test_vmap.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestVmap(mlx_tests.MLXTestCase):
|
||||
def test_basics(self):
|
||||
# Can't vmap over scalars
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(mx.exp)(mx.array(1.0))
|
||||
|
||||
# Invalid input
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(mx.exp)("hello")
|
||||
|
||||
# Invalid axes
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(mx.exp, in_axes="hello")(mx.array([0, 1]))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1]))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(mx.exp, out_axes="hello")(mx.array([0, 1]))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1]))
|
||||
|
||||
def test_unary(self):
|
||||
ops = [
|
||||
"abs",
|
||||
"cos",
|
||||
"erf",
|
||||
"erfinv",
|
||||
"exp",
|
||||
"log",
|
||||
"log1p",
|
||||
"log2",
|
||||
"log10",
|
||||
"logical_not",
|
||||
"negative",
|
||||
"reciprocal",
|
||||
"rsqrt",
|
||||
"sigmoid",
|
||||
"sign",
|
||||
"sin",
|
||||
"sqrt",
|
||||
"square",
|
||||
]
|
||||
ops = ["erfinv"]
|
||||
for opname in ops:
|
||||
with self.subTest(op=opname):
|
||||
op = getattr(mx, opname)
|
||||
x = mx.arange(5)
|
||||
y = mx.vmap(op)(x)
|
||||
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
|
||||
|
||||
x = mx.arange(8).reshape(2, 4)
|
||||
y = mx.vmap(op)(x)
|
||||
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
|
||||
|
||||
y = mx.vmap(op, in_axes=1, out_axes=1)(x)
|
||||
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
|
||||
|
||||
def test_binary(self):
|
||||
ops = [
|
||||
"add",
|
||||
"divide",
|
||||
"equal",
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"less",
|
||||
"less_equal",
|
||||
"logaddexp",
|
||||
"maximum",
|
||||
"minimum",
|
||||
"multiply",
|
||||
"power",
|
||||
"subtract",
|
||||
]
|
||||
for opname in ops:
|
||||
with self.subTest(op=opname):
|
||||
op = getattr(mx, opname)
|
||||
x = mx.random.uniform(shape=(5,))
|
||||
y = mx.random.uniform(shape=(5,))
|
||||
out = mx.vmap(op)(x, y)
|
||||
self.assertTrue(mx.array_equal(out, op(x, y)))
|
||||
|
||||
x = mx.random.uniform(shape=(2, 4))
|
||||
y = mx.random.uniform(shape=(2, 4))
|
||||
out = mx.vmap(op)(x, y)
|
||||
self.assertTrue(mx.array_equal(out, op(x, y)))
|
||||
|
||||
out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y)
|
||||
self.assertTrue(mx.array_equal(out, op(x, y)))
|
||||
|
||||
y = mx.random.uniform(shape=(4, 2))
|
||||
out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y)
|
||||
self.assertTrue(mx.array_equal(out, op(x, y.T)))
|
||||
|
||||
out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y)
|
||||
self.assertTrue(mx.array_equal(out, op(x, y.T).T))
|
||||
|
||||
def test_tree(self):
|
||||
def my_fun(tree):
|
||||
return (tree["a"] + tree["b"][0]) * tree["b"][1]
|
||||
|
||||
tree = {
|
||||
"a": mx.random.uniform(shape=(2, 4)),
|
||||
"b": (
|
||||
mx.random.uniform(shape=(2, 4)),
|
||||
mx.random.uniform(shape=(2, 4)),
|
||||
),
|
||||
}
|
||||
out = mx.vmap(my_fun)(tree)
|
||||
expected = my_fun(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
|
||||
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
tree = {
|
||||
"a": mx.random.uniform(shape=(2, 4)),
|
||||
"b": (
|
||||
mx.random.uniform(shape=(4, 2)),
|
||||
mx.random.uniform(shape=(4, 2)),
|
||||
),
|
||||
}
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
|
||||
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def my_fun(x, y):
|
||||
return {"a": x + y, "b": x * y}
|
||||
|
||||
x = mx.random.uniform(shape=(2, 4))
|
||||
y = mx.random.uniform(shape=(2, 4))
|
||||
out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y)
|
||||
expected = my_fun(x, y)
|
||||
self.assertTrue(mx.array_equal(out["a"], expected["a"]))
|
||||
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(my_fun, in_axes=0, out_axes={"a": 0, "c": 1})(x, y)
|
||||
|
||||
out = mx.vmap(my_fun, in_axes=0, out_axes={"a": 1, "b": 0})(x, y)
|
||||
expected = my_fun(x, y)
|
||||
self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
|
||||
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user