Commit by Luca Arnaboldi, 2024-02-20 09:25:39 +01:00
141 changed files with 7409 additions and 1945 deletions

View File

@@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(vals)
self.assertEqual(x.tolist(), vals)
# Half types
vals = [1.0, 2.0, 3.0, 4.0, 5.0]
x = mx.array(vals, dtype=mx.float16)
self.assertEqual(x.tolist(), vals)
x = mx.array(vals, dtype=mx.bfloat16)
self.assertEqual(x.tolist(), vals)
def test_array_np_conversion(self):
# Shape test
a = np.array([])
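Note (not part of this commit): a minimal sketch of the half-type round trip the new assertions exercise, assuming values that are exactly representable in half precision.

import mlx.core as mx

# Small integers are exact in float16/bfloat16, so tolist() round-trips them.
vals = [1.0, 2.0, 3.0, 4.0, 5.0]
x = mx.array(vals, dtype=mx.float16)
assert x.tolist() == vals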

View File

@@ -2,6 +2,7 @@
import io
import unittest
from functools import partial
import mlx.core as mx
import mlx_tests
@@ -301,6 +302,243 @@ class TestCompile(mlx_tests.MLXTestCase):
cdfdx = mx.grad(outer)(x)
self.assertTrue(mx.allclose(dfdx, cdfdx))
def test_compile_capture(self):
# Test update captured state outside compiled function
state = {"y": mx.array(2)}
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state["y"]
return x
test_state(mx.array(1))
# Check the state is unchanged
self.assertEqual(state["y"], 2)
# Check the updated state is used
state["y"] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Capture list
state = [mx.array(2)]
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state[0]
return x
out = test_state(mx.array(1))
self.assertEqual(out.item(), 3)
state[0] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Capture tuple of list
state = ([mx.array(2)],)
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state[0][0]
return x
out = test_state(mx.array(1))
self.assertEqual(out.item(), 3)
state[0][0] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Test state updated inside compiled function
state = {}
@partial(mx.compile, outputs=state)
def test_state(x):
state["y"] = x + 3
return mx.abs(x)
test_state(mx.array(-1))
self.assertEqual(state["y"].item(), 2)
# Test state changed inside compiled function
# triggers recompile
state = {}
@partial(mx.compile, inputs=state, outputs=state)
def test_state(x):
y = state.get("y", mx.array(0))
state["y"] = x + y
return x + 2 * y
test_state(mx.array(1))
self.assertEqual(state["y"].item(), 1)
test_state(mx.array(1))
self.assertEqual(state["y"].item(), 2)
def test_compile_rng(self):
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def fun():
return mx.random.uniform(shape=(10, 10))
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
def test_compile_kwargs(self):
@mx.compile
def fun(x, y, z):
return x + y + z
x = mx.array(1)
y = mx.array(2)
z = mx.array(3)
out = fun(x, y=y, z=z)
self.assertEqual(out.item(), 6)
def test_shapeless_compile(self):
y = 1
@partial(mx.compile, shapeless=True)
def fun(x):
return x + y
x = mx.array([1, 2])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
# The function is not recompiled, so the change
# to y should not be reflected in the output
y = 2
x = mx.array([1, 2, 3])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
# Type change recompiles
x = mx.array([1.0, 2.0, 3.0])
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
# Dim change recompiles
x = mx.array([[1, 2, 3]])
self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))
def test_shapeless_compile_with_broadcasts(self):
x = mx.ones((2, 2))
y = mx.array([2, 2])
def fun(x, y):
return x * y
cfun = mx.compile(fun, shapeless=True)
self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
y = mx.array([[3]])
self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
def test_shapeless_compile_with_reduction(self):
# Test shapeless compile with a reduction
z = 1
@partial(mx.compile, shapeless=True)
def fun(x, y):
return x + y.sum(0, keepdims=True) + z
x = mx.ones((2, 2), mx.int32)
y = mx.ones((2, 2), mx.int32)
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4)))
x = mx.ones((3, 3), mx.int32)
y = mx.ones((3, 3), mx.int32)
z = 2
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5)))
x1 = mx.array([[1, 2], [3, 4], [5, 6]])
x2 = mx.array([[1, 2]])
def fun(x):
return x * x.sum(-1, keepdims=True)
cfun = mx.compile(fun, shapeless=True)
mx.eval(cfun(x1))
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_compile_with_constant(self):
# Test float
@partial(mx.compile)
def fun(x, y):
return x + y
z = fun(mx.array(1.0), 1.0)
self.assertEqual(z.item(), 2.0)
z = fun(mx.array(1.0), 2.0)
self.assertEqual(z.item(), 3.0)
z = fun(mx.array(1.0), y=1.0)
self.assertEqual(z.item(), 2.0)
z = fun(mx.array(1.0), y=3.0)
self.assertEqual(z.item(), 4.0)
# Test tuple
@partial(mx.compile)
def fun(x, y=(1, 2)):
return x + y[0] + y[1]
z = fun(mx.array(1))
self.assertEqual(z.item(), 4)
z = fun(mx.array(1), (2, 2))
self.assertEqual(z.item(), 5)
z = fun(mx.array(1), (2, 1))
self.assertEqual(z.item(), 4)
# Test bool
@partial(mx.compile)
def fun(x, y):
if y:
return x + 1
else:
return x + 2
z = fun(mx.array(1), True)
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), False)
self.assertEqual(z.item(), 3)
# Test string
@partial(mx.compile)
def fun(x, y):
if y == "one":
return x + 1
else:
return x + 2
z = fun(mx.array(1), "one")
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), "two")
self.assertEqual(z.item(), 3)
if __name__ == "__main__":
unittest.main()
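Note (not part of this commit): a minimal sketch of the state-capture pattern these tests exercise; `counter` and `step` are hypothetical names.

from functools import partial

import mlx.core as mx

counter = {"steps": mx.array(0)}

# Declaring the dict as both input and output lets the compiled function
# read and mutate captured state across calls.
@partial(mx.compile, inputs=counter, outputs=counter)
def step(x):
    counter["steps"] = counter["steps"] + 1
    return 2 * x

step(mx.array(3))
step(mx.array(3))
print(counter["steps"].item())  # 2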

View File

@@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase):
# Restore device
mx.set_default_device(device)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_device_context(self):
default = mx.default_device()
diff = mx.cpu if default == mx.gpu else mx.gpu
self.assertNotEqual(default, diff)
with mx.stream(diff):
a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
mx.eval(a)
self.assertEqual(mx.default_device(), diff)
self.assertEqual(mx.default_device(), default)
def test_op_on_device(self):
x = mx.array(1.0)
y = mx.array(1.0)
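Note (not part of this commit): the context-manager pattern under test, sketched standalone.

import mlx.core as mx

# Inside the block, ops run on the given stream's device; the previous
# default is restored when the block exits.
with mx.stream(mx.cpu):
    a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
    mx.eval(a)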

View File

@@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase):
y = dfun_dx(mx.array(1.0))
self.assertEqual(y.item(), 6.0)
def test_eval_mixed(self):
x = mx.array(1) + 1 + 1
y = 0
z = "hello"
state = [x, y, z]
mx.eval(state)
self.assertEqual(x.item(), 3)
if __name__ == "__main__":
unittest.main()
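Note (not part of this commit): a sketch of the behavior test_eval_mixed checks, assuming mx.eval walks nested Python containers and skips non-array leaves.

import mlx.core as mx

state = [mx.array(1) + 1 + 1, 0, "hello"]
mx.eval(state)          # evaluates the lazy array, ignores the int and str
print(state[0].item())  # 3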

python/tests/test_fast.py (new file, 158 lines)
View File

@@ -0,0 +1,158 @@
# Copyright © 2023-2024 Apple Inc.
import math
import unittest
import mlx.core as mx
import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset):
N = x.shape[1] + offset
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional:
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return mx.reshape(rx, x.shape)
else:
x1 = x[..., : dims // 2]
x2 = x[..., dims // 2 : dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if dims < x.shape[-1]:
rx = mx.concatenate([rx1, rx2, x[..., dims:]], axis=-1)
else:
rx = mx.concatenate([rx1, rx2], axis=-1)
return rx
class TestFast(mlx_tests.MLXTestCase):
def test_rope(self):
T = 4
# Defaults: dims, dtype, base, scale, offset, traditional
defaults = (8, mx.float32, 10000.0, 1.0, 0, False)
# Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
# Test cases:
dtypes = [mx.float32, mx.float16, mx.bfloat16]
bases = [10000.0, 1000000.0]
scales = [1.0, 2.0]
offsets = [0, 3]
traditional = [True, False]
for traditional in [True, False]:
dims, dtype, _, scale, offset, _ = defaults
for base in bases:
x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
dims, _, base, scale, offset, _ = defaults
for dtype in dtypes:
x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
ry = rope_orig(
x.astype(mx.float32), dims, traditional, base, scale, offset
)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
if dtype != mx.float32:
self.assertLessEqual(
mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
dims, dtype, base, scale, _, _ = defaults
for offset in offsets:
x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
dims, dtype, base, _, offset, _ = defaults
for scale in scales:
x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_fast_transforms(self):
x = mx.random.uniform(shape=(2, 2, 8))
defaults = (8, False, 10000.0, 1.0, 0)
dims, traditional, base, scale, offset = defaults
# VJP
_, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
_, vjp_fast_out = mx.vjp(
lambda x: mx.fast.rope(
x, dims, traditional=traditional, base=base, scale=scale, offset=offset
),
(x,),
(mx.ones_like(x),),
)
self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0]))
# JVP
_, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
_, jvp_fast_out = mx.jvp(
lambda x: mx.fast.rope(
x, dims, traditional=traditional, base=base, scale=scale, offset=offset
),
(x,),
(mx.ones_like(x),),
)
self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))
# VMAP
x = mx.random.uniform(shape=(2, 2, 2, 8))
vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)
vmap_fast_out = mx.vmap(
lambda x: mx.fast.rope(
x, dims, traditional=traditional, base=base, scale=scale, offset=offset
)
)(x)
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
if __name__ == "__main__":
unittest.main()
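Note (not part of this commit): a minimal call to the fused kernel that rope_orig serves as the reference implementation for.

import mlx.core as mx

x = mx.random.uniform(shape=(2, 4, 8))  # (batch, sequence, features)
y = mx.fast.rope(x, 8, traditional=False, base=10000.0, scale=1.0, offset=0)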

View File

@@ -19,72 +19,73 @@ class TestFFT(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
def test_fft(self):
def check_mx_np(op_mx, op_np, a_np, **kwargs):
out_np = op_np(a_np, **kwargs)
a_mx = mx.array(a_np)
out_mx = op_mx(a_mx, **kwargs)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
with mx.stream(mx.cpu):
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np)
# Check with slicing and padding
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
# Check different axes
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
# Check real fft
a_np = np.random.rand(100).astype(np.float32)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
# Check real inverse
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
def test_fftn(self):
with mx.stream(mx.cpu):
r = np.random.randn(8, 8, 8).astype(np.float32)
i = np.random.randn(8, 8, 8).astype(np.float32)
a = r + 1j * i
axes = [None, (1, 2), (2, 1), (0, 2)]
shapes = [None, (10, 5), (5, 10)]
ops = [
"fft2",
"ifft2",
"rfft2",
"irfft2",
"fftn",
"ifftn",
"rfftn",
"irfftn",
]
for op, ax, s in itertools.product(ops, axes, shapes):
x = a
if op in ["rfft2", "rfftn"]:
x = r
self.check_mx_np(op, x, axes=ax, s=s)
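Note (not part of this commit): a standalone sketch of the pattern this diff migrates to, pinning FFT ops to the CPU stream with a context manager instead of swapping the default device.

import numpy as np
import mlx.core as mx

with mx.stream(mx.cpu):
    a_np = (np.random.rand(64) + 1j * np.random.rand(64)).astype(np.complex64)
    out_mx = mx.fft.fft(mx.array(a_np))
    assert np.allclose(np.fft.fft(a_np), out_mx, atol=1e-5, rtol=1e-6)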
if __name__ == "__main__":

View File

@@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
def test_save_and_load_safetensors(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
with self.assertRaises(Exception):
mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
mx.save_safetensors(
"test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
)
res = mx.load("test.safetensors", return_metadata=True)
self.assertEqual(len(res), 2)
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
for dt in self.dtypes + ["bfloat16"]:
with self.subTest(dtype=dt):
@@ -75,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase):
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
)
save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
"test": (
mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
)
}
with open(save_file_mlx, "wb") as f:
@@ -104,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase):
self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
)
save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
"test": (
mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
)
}
mx.save_gguf(save_file_mlx, save_dict)
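Note (not part of this commit): the safetensors metadata round trip added above, sketched standalone; file names are hypothetical.

import mlx.core as mx

# Metadata must be a dict of str -> str; non-string values raise.
mx.save_safetensors("weights", {"w": mx.ones((2, 2))}, {"format": "mlx"})
arrays, metadata = mx.load("weights.safetensors", return_metadata=True)
assert metadata == {"format": "mlx"}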

View File

@@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase):
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
# With weights, no label smoothing
weights = mx.array([1.0, 2.0, 1.0, 2.0])
expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944])
loss = nn.losses.binary_cross_entropy(
logits, targets, weights=weights, reduction="none"
)
self.assertTrue(mx.allclose(loss, expected))
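Note (not part of this commit): a property-style sketch of what the weighted case checks. With reduction="none", weights scale the per-example losses elementwise; the logits and targets here are hypothetical.

import mlx.core as mx
import mlx.nn as nn

logits = mx.array([0.1, 0.2, 1.2, 0.9])
targets = mx.array([0, 0, 1, 1])
weights = mx.array([1.0, 2.0, 1.0, 2.0])

plain = nn.losses.binary_cross_entropy(logits, targets, reduction="none")
weighted = nn.losses.binary_cross_entropy(
    logits, targets, weights=weights, reduction="none"
)
assert mx.allclose(weighted, plain * weights)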
def _test_probs_as_inputs():
probs = mx.array([0.5, 0.6, 0.7, 0.8])
targets = mx.array([0, 0, 1, 1])

View File

@@ -71,7 +71,7 @@ class TestBase(mlx_tests.MLXTestCase):
def test_save_safetensors_weights(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU())
m = make_model()
tdir = tempfile.TemporaryDirectory()
@@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase):
]
)
def test_module_state(self):
m = nn.Linear(10, 1)
m.state["hello"] = "world"
self.assertEqual(m.state["hello"], "world")
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):
@@ -900,6 +905,347 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.float16)
def test_pooling(self):
# Test 1d pooling
x = mx.array(
[
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
[[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]],
]
)
expected_max_pool_output_no_padding_stride_1 = [
[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
[[15.0, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
]
expected_max_pool_output_no_padding_stride_2 = [
[[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]],
[[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]],
]
expected_max_pool_output_padding_1_stride_2 = [
[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
[[12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
]
expected_max_pool_output_padding_1_stride_2_kernel_3 = [
[[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]],
[[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]],
]
expected_avg_pool_output_no_padding_stride_1 = [
[
[1.5000, 2.5000, 3.5000],
[4.5000, 5.5000, 6.5000],
[7.5000, 8.5000, 9.5000],
],
[
[13.5000, 14.5000, 15.5000],
[16.5000, 17.5000, 18.5000],
[19.5000, 20.5000, 21.5000],
],
]
expected_avg_pool_output_no_padding_stride_2 = [
[[1.5000, 2.5000, 3.5000], [7.5000, 8.5000, 9.5000]],
[[13.5000, 14.5000, 15.5000], [19.5000, 20.5000, 21.5000]],
]
expected_avg_pool_output_padding_1_stride_2 = [
[
[0.0000, 0.5000, 1.0000],
[4.5000, 5.5000, 6.5000],
[4.5000, 5.0000, 5.5000],
],
[
[6.0000, 6.5000, 7.0000],
[16.5000, 17.5000, 18.5000],
[10.5000, 11.0000, 11.5000],
],
]
expected_avg_pool_output_padding_1_kernel_3 = [
[[1, 1.66667, 2.33333], [6, 7, 8]],
[[9, 9.66667, 10.3333], [18, 19, 20]],
]
self.assertTrue(
np.array_equal(
nn.MaxPool1d(kernel_size=2, stride=1, padding=0)(x),
expected_max_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool1d(kernel_size=2, stride=2, padding=0)(x),
expected_max_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool1d(kernel_size=2, stride=2, padding=1)(x),
expected_max_pool_output_padding_1_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool1d(kernel_size=3, stride=2, padding=1)(x),
expected_max_pool_output_padding_1_stride_2_kernel_3,
)
)
self.assertTrue(
np.allclose(
nn.AvgPool1d(kernel_size=2, stride=1, padding=0)(x),
expected_avg_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.allclose(
nn.AvgPool1d(kernel_size=2, stride=2, padding=0)(x),
expected_avg_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.allclose(
nn.AvgPool1d(kernel_size=2, stride=2, padding=1)(x),
expected_avg_pool_output_padding_1_stride_2,
)
)
self.assertTrue(
np.allclose(
nn.AvgPool1d(kernel_size=3, stride=2, padding=1)(x),
expected_avg_pool_output_padding_1_kernel_3,
)
)
# Test 2d pooling
x = mx.array(
[
[
[[0, 16], [1, 17], [2, 18], [3, 19]],
[[4, 20], [5, 21], [6, 22], [7, 23]],
[[8, 24], [9, 25], [10, 26], [11, 27]],
[[12, 28], [13, 29], [14, 30], [15, 31]],
]
]
)
expected_max_pool_output_no_padding_stride_1 = [
[
[[5, 21], [6, 22], [7, 23]],
[[9, 25], [10, 26], [11, 27]],
[[13, 29], [14, 30], [15, 31]],
]
]
expected_max_pool_output_no_padding_stride_2 = [
[[[5, 21], [7, 23]], [[13, 29], [15, 31]]]
]
expected_max_pool_output_padding_1 = [
[
[[0, 16], [2, 18], [3, 19]],
[[8, 24], [10, 26], [11, 27]],
[[12, 28], [14, 30], [15, 31]],
]
]
expected_mean_pool_output_no_padding_stride_1 = [
[
[[2.5000, 18.5000], [3.5000, 19.5000], [4.5000, 20.5000]],
[[6.5000, 22.5000], [7.5000, 23.5000], [8.5000, 24.5000]],
[[10.5000, 26.5000], [11.5000, 27.5000], [12.5000, 28.5000]],
]
]
expected_mean_pool_output_no_padding_stride_2 = [
[
[[2.5000, 18.5000], [4.5000, 20.5000]],
[[10.5000, 26.5000], [12.5000, 28.5000]],
]
]
expected_mean_pool_output_padding_1 = [
[
[[0.0000, 4.0000], [0.7500, 8.7500], [0.7500, 4.7500]],
[[3.0000, 11.0000], [7.5000, 23.5000], [4.5000, 12.5000]],
[[3.0000, 7.0000], [6.7500, 14.7500], [3.7500, 7.7500]],
]
]
self.assertTrue(
np.array_equal(
nn.MaxPool2d(kernel_size=2, stride=1, padding=0)(x),
expected_max_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool2d(kernel_size=2, stride=2, padding=0)(x),
expected_max_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(x),
expected_max_pool_output_padding_1,
)
)
# Average pooling
self.assertTrue(
np.allclose(
nn.AvgPool2d(kernel_size=2, stride=1, padding=0)(x),
expected_mean_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0)(x),
expected_mean_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool2d(kernel_size=2, stride=2, padding=1)(x),
expected_mean_pool_output_padding_1,
)
)
# Test multiple batches
x = mx.array(
[
[
[[0, 1], [2, 3], [4, 5], [6, 7]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[16, 17], [18, 19], [20, 21], [22, 23]],
[[24, 25], [26, 27], [28, 29], [30, 31]],
],
[
[[32, 33], [34, 35], [36, 37], [38, 39]],
[[40, 41], [42, 43], [44, 45], [46, 47]],
[[48, 49], [50, 51], [52, 53], [54, 55]],
[[56, 57], [58, 59], [60, 61], [62, 63]],
],
]
)
expected_max_pool_output = [
[[[10.0, 11.0], [14.0, 15.0]], [[26.0, 27.0], [30.0, 31.0]]],
[[[42.0, 43.0], [46.0, 47.0]], [[58.0, 59.0], [62.0, 63.0]]],
]
expected_avg_pool_output = [
[[[2.22222, 2.66667], [5.33333, 6]], [[11.3333, 12], [20, 21]]],
[[[16.4444, 16.8889], [26.6667, 27.3333]], [[32.6667, 33.3333], [52, 53]]],
]
self.assertTrue(
np.array_equal(
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x),
expected_max_pool_output,
)
)
self.assertTrue(
np.allclose(
nn.AvgPool2d(kernel_size=3, stride=2, padding=1)(x),
expected_avg_pool_output,
)
)
# Test irregular kernel (2, 4), stride (3, 1) and padding (1, 2)
x = mx.array(
[
[
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
[[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]],
[[24, 25, 26], [27, 28, 29], [30, 31, 32], [33, 34, 35]],
[[36, 37, 38], [39, 40, 41], [42, 43, 44], [45, 46, 47]],
],
[
[[48, 49, 50], [51, 52, 53], [54, 55, 56], [57, 58, 59]],
[[60, 61, 62], [63, 64, 65], [66, 67, 68], [69, 70, 71]],
[[72, 73, 74], [75, 76, 77], [78, 79, 80], [81, 82, 83]],
[[84, 85, 86], [87, 88, 89], [90, 91, 92], [93, 94, 95]],
],
]
)
expected_irregular_max_pool_output = [
[
[
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0],
[9.0, 10.0, 11.0],
[9.0, 10.0, 11.0],
],
[
[39.0, 40.0, 41.0],
[42.0, 43.0, 44.0],
[45.0, 46.0, 47.0],
[45.0, 46.0, 47.0],
[45.0, 46.0, 47.0],
],
],
[
[
[51.0, 52.0, 53.0],
[54.0, 55.0, 56.0],
[57.0, 58.0, 59.0],
[57.0, 58.0, 59.0],
[57.0, 58.0, 59.0],
],
[
[87.0, 88.0, 89.0],
[90.0, 91.0, 92.0],
[93.0, 94.0, 95.0],
[93.0, 94.0, 95.0],
[93.0, 94.0, 95.0],
],
],
]
expected_irregular_average_pool_output = [
[
[
[0.3750, 0.6250, 0.8750],
[1.1250, 1.5000, 1.8750],
[2.2500, 2.7500, 3.2500],
[2.2500, 2.6250, 3.0000],
[1.8750, 2.1250, 2.3750],
],
[
[15.7500, 16.2500, 16.7500],
[24.7500, 25.5000, 26.2500],
[34.5000, 35.5000, 36.5000],
[27.0000, 27.7500, 28.5000],
[18.7500, 19.2500, 19.7500],
],
],
[
[
[12.3750, 12.6250, 12.8750],
[19.1250, 19.5000, 19.8750],
[26.2500, 26.7500, 27.2500],
[20.2500, 20.6250, 21.0000],
[13.8750, 14.1250, 14.3750],
],
[
[39.7500, 40.2500, 40.7500],
[60.7500, 61.5000, 62.2500],
[82.5000, 83.5000, 84.5000],
[63.0000, 63.7500, 64.5000],
[42.7500, 43.2500, 43.7500],
],
],
]
self.assertTrue(
np.array_equal(
nn.MaxPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x),
expected_irregular_max_pool_output,
)
)
self.assertTrue(
np.allclose(
nn.AvgPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x),
expected_irregular_average_pool_output,
)
)
# Test repr
self.assertEqual(
str(nn.MaxPool1d(kernel_size=3, padding=2)),
"MaxPool1d(kernel_size=(3,), stride=(3,), padding=(2,))",
)
self.assertEqual(
str(nn.AvgPool1d(kernel_size=2, stride=3)),
"AvgPool1d(kernel_size=(2,), stride=(3,), padding=(0,))",
)
self.assertEqual(
str(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
"MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))",
)
self.assertEqual(
str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
)
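Note (not part of this commit): the pooling layers expect channels-last input; a minimal sketch.

import mlx.core as mx
import mlx.nn as nn

x = mx.arange(24, dtype=mx.float32).reshape(2, 4, 3)  # (N, L, C)
y = nn.MaxPool1d(kernel_size=2, stride=2)(x)
print(y.shape)  # (2, 2, 3)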
if __name__ == "__main__":
unittest.main()

View File

@@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
import math
import os
import unittest
from itertools import permutations
@@ -274,6 +275,20 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), 1)
z = -1 % x
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), 1)
z = -1 % -x
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), -1)
x = mx.arange(10).astype(dt) - 5
y = x % 5
z = x % -5
self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])
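Note (not part of this commit): the semantics under test, sketched standalone. Like Python, the remainder takes the sign of the divisor.

import mlx.core as mx

x = mx.arange(10).astype(mx.int32) - 5
print((x % 5).tolist())   # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
print((x % -5).tolist())  # [0, -4, -3, -2, -1, 0, -4, -3, -2, -1]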
def test_comparisons(self):
a = mx.array([0.0, 1.0, 5.0])
b = mx.array([-1.0, 2.0, 5.0])
@@ -1012,6 +1027,9 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [[3, 4]])
self.assertEqual(z.tolist(), [[5, 6]])
with self.assertRaises(ValueError):
mx.split(a, 3, axis=2)
a = mx.arange(8)
x, y, z = mx.split(a, [1, 5])
self.assertEqual(x.tolist(), [0])
@@ -1318,9 +1336,7 @@ class TestOps(mlx_tests.MLXTestCase):
for d in dims:
anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)
for n_bsx in range(d):
bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx)
for _ in range(trial_mul * d):
amlx = mx.array(anp)
bmlx = mx.array(bnp)
@@ -1371,6 +1387,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue((a[:-1] < 1e-9).all())
self.assertEqual(a[-1], 1)
# Sliced inputs
y = mx.random.uniform(shape=(8, 4))
out = mx.softmax(y[:, 0:2], axis=-1)
self.assertAlmostEqual(out.sum().item(), 8.0)
def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32)
b_npy = np.random.randn(32, 32, 32)
@@ -1566,6 +1587,10 @@ class TestOps(mlx_tests.MLXTestCase):
d_np = np.take(b_mx, np.arange(kth), axis=axis)
self.assertTrue(np.all(d_np <= c_mx))
@unittest.skipIf(
os.getenv("LOW_MEMORY", None) is not None,
"This test requires a lot of memory",
)
def test_large_binary(self):
a = mx.ones([1000, 2147484], mx.int8)
b = mx.ones([2147484], mx.int8)
@@ -1677,6 +1702,8 @@ class TestOps(mlx_tests.MLXTestCase):
def test_repeat(self):
# Setup data for the tests
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
# Test repeat 0 times
self.assertCmpNumpy([data, 0], mx.repeat, np.repeat)
# Test repeat along axis 0
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
# Test repeat along axis 1
@@ -1856,6 +1883,96 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array(np.diag(x, k=-1))
self.assertTrue(mx.array_equal(result, expected))
def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
def test_atleast_2d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
def test_atleast_3d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
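Note (not part of this commit): the scalar case, which the list-based inputs above do not hit, assuming mx mirrors NumPy here.

import mlx.core as mx

a = mx.array(1.0)              # 0-d array
print(mx.atleast_1d(a).shape)  # (1,)
print(mx.atleast_2d(a).shape)  # (1, 1)
print(mx.atleast_3d(a).shape)  # (1, 1, 1)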
if __name__ == "__main__":
unittest.main()

View File

@@ -1,50 +1,215 @@
# Copyright © 2023 Apple Inc.
import inspect
import math
import unittest
from functools import partial
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
from mlx.utils import tree_flatten, tree_map
def get_all_optimizers():
classes = dict()
for name, obj in inspect.getmembers(opt):
if (
inspect.isclass(obj)
and issubclass(obj, opt.Optimizer)
and obj != opt.Optimizer
):
classes[name] = obj
return classes
def tree_equal(fn, *args):
return all(v for _, v in tree_flatten(tree_map(fn, *args)))
optimizers_dict = get_all_optimizers()
class TestOptimizers(mlx_tests.MLXTestCase):
def test_optimizer_state(self):
optim = opt.SGD(0.1)
optim.state["hello"] = "world"
self.assertEqual(optim.state["hello"], "world")
optim.state = {0: 1}
self.assertEqual(optim.state, {0: 1})
def test_optimizers(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
for optim_class in optimizers_dict.values():
optim = optim_class(0.1)
update = optim.apply_gradients(grads, params)
mx.eval(update)
equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update)
all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
self.assertTrue(all_equal)
def test_types_conserved(self):
params = {"w": mx.ones((5, 5), mx.float16)}
grads = tree_map(lambda x: mx.ones_like(x), params)
for optim_class in optimizers_dict.values():
optim = optim_class(0.1)
update = optim.apply_gradients(grads, params)
self.assertEqual(update["w"].dtype, mx.float16)
def test_sgd(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Implicit init
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
optim.apply_gradients(grads, params)
self.assertTrue(
tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state)
)
def test_rmsprop(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.RMSprop(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Implicit init
alpha = 0.99
optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
optim.apply_gradients(grads, params)
self.assertTrue(
tree_equal(
lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state
)
)
def test_adagrad(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Adagrad(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adadelta(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.AdaDelta(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adam(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:
optim = optimizer(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_lion(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Lion(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adafactor(self):
x = mx.zeros((5, 5))
grad = mx.ones_like(x)
optimizer = opt.Adafactor()
for _ in range(2):
xp = optimizer.apply_gradients(grad, x)
self.assertEqual(xp.dtype, x.dtype)
self.assertEqual(xp.shape, x.shape)
@@ -52,11 +217,129 @@ class TestOptimizers(mlx_tests.MLXTestCase):
grad = mx.ones_like(x)
optimizer = opt.Adafactor()
for _ in range(2):
xp = optimizer.apply_gradients(grad, x)
self.assertEqual(xp.dtype, x.dtype)
self.assertEqual(xp.shape, x.shape)
self.assertEqual(optimizer.state["step"], 2)
def test_compiled_optimizer(self):
model = nn.Linear(10, 10)
x = mx.random.uniform(shape=(2, 10))
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
orig_params = model.parameters()
def loss(model, x):
return model(x).sum()
# Uncompiled version
def step(x):
_, grad = nn.value_and_grad(model, loss)(model, x)
optim.update(model, grad)
step(x)
uncompiled_params = model.parameters()
# Pure version
def loss(params, x):
model.update(params)
return model(x).sum()
model.update(orig_params)
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
@mx.compile
def step(params, opt_state, x):
grad = mx.grad(loss)(params, x)
optim.state = opt_state
params = optim.apply_gradients(grad, params)
return params, optim.state
optim.init(model.parameters())
pure_params, _ = step(model.parameters(), optim.state, x)
self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"]))
self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"]))
# Impure version
def loss(model, x):
return model(x).sum()
model.update(orig_params)
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
state = [model.state, optim.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x):
_, grad = nn.value_and_grad(model, loss)(model, x)
optim.update(model, grad)
step(x)
impure_params = model.parameters()
self.assertTrue(
mx.allclose(impure_params["weight"], uncompiled_params["weight"])
)
self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"]))
def test_update_lr_compiled(self):
params = {"w": mx.ones((5, 5))}
grads = tree_map(lambda x: mx.ones_like(x), params)
optim = opt.SGD(-1.0)
@partial(mx.compile, inputs=optim.state)
def update(grads):
return optim.apply_gradients(grads, params)
result = update(grads)
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0)))
optim.learning_rate = -2.0
result = update(grads)
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
class TestSchedulers(unittest.TestCase):
def test_decay_lr(self):
for optim_class in optimizers_dict.values():
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
optimizer = optim_class(learning_rate=lr_schedule)
params = {"w": mx.ones((5, 5))}
grads = tree_map(lambda x: mx.ones_like(x), params)
for it in range(10):
expected_lr = 0.1 * (0.9**it)
self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
optimizer.apply_gradients(grads, params)
def test_step_decay(self):
lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
lr = lr_schedule(2500)
expected_lr = 0.1 * (0.9**2)
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_exponential_decay(self):
lr_schedule = opt.exponential_decay(1e-1, 0.99)
lr = lr_schedule(10)
expected_lr = 0.1 * (0.99**10)
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_cosine_decay(self):
lr_schedule = opt.cosine_decay(0.1, 10)
lr = lr_schedule(4)
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_compile_with_schedule(self):
lr_schedule = opt.exponential_decay(1e-1, 0.9)
optimizer = opt.SGD(learning_rate=lr_schedule)
@partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
def update():
optimizer.update({}, {})
for step in range(5):
update()
self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
if __name__ == "__main__":
unittest.main()
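Note (not part of this commit): the scheduler plumbing these tests rely on, sketched standalone. Passing a schedule as learning_rate makes the rate a function of the optimizer's step count.

import mlx.optimizers as opt

schedule = opt.cosine_decay(1e-1, 1000)     # initial lr, decay steps
optimizer = opt.SGD(learning_rate=schedule)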

View File

@@ -165,6 +165,70 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_non_multiples(self):
w = mx.random.normal(shape=(33, 256))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)
# Test qmv
x = mx.random.normal(shape=(1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm_t
x = mx.random.normal(shape=(10, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qvm
x = mx.random.normal(shape=(1, 33))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm
x = mx.random.normal(shape=(10, 33))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Smaller than 8
w = mx.random.normal(shape=(3, 256))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)
# Test qmv
x = mx.random.normal(shape=(1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm_t
x = mx.random.normal(shape=(10, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qvm
x = mx.random.normal(shape=(1, 3))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm
x = mx.random.normal(shape=(10, 3))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
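Note (not part of this commit): the quantize / matmul round trip at the core of these checks.

import mlx.core as mx

w = mx.random.normal(shape=(33, 256))  # rows need not be a multiple of 32
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)

x = mx.random.normal(shape=(1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
# y_q approximates x @ w_hat.T up to quantization error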
if __name__ == "__main__":
unittest.main()

View File

@@ -80,6 +80,20 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.normal(dtype=t)
self.assertEqual(a.dtype, t)
# Generate with a given mean and standard deviation
loc = 1.0
scale = 2.0
a = mx.random.normal(shape=(3, 2), loc=loc, scale=scale, key=key)
b = scale * mx.random.normal(shape=(3, 2), key=key) + loc
self.assertTrue(mx.allclose(a, b))
a = mx.random.normal(
shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key
)
b = scale * mx.random.normal(shape=(3, 2), dtype=mx.float16, key=key) + loc
self.assertTrue(mx.allclose(a, b))
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
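Note (not part of this commit): the identity the new loc/scale test asserts, sketched with an explicit key.

import mlx.core as mx

key = mx.random.key(0)
a = mx.random.normal(shape=(3, 2), loc=1.0, scale=2.0, key=key)
b = 2.0 * mx.random.normal(shape=(3, 2), key=key) + 1.0
assert mx.allclose(a, b)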
def test_randint(self):