Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Merge branch 'main' of https://github.com/ml-explore/mlx
@@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase):
        x = mx.array(vals)
        self.assertEqual(x.tolist(), vals)

        # Half types
        vals = [1.0, 2.0, 3.0, 4.0, 5.0]
        x = mx.array(vals, dtype=mx.float16)
        self.assertEqual(x.tolist(), vals)

        x = mx.array(vals, dtype=mx.bfloat16)
        self.assertEqual(x.tolist(), vals)

    def test_array_np_conversion(self):
        # Shape test
        a = np.array([])

@@ -2,6 +2,7 @@

import io
import unittest
from functools import partial

import mlx.core as mx
import mlx_tests

@@ -301,6 +302,243 @@ class TestCompile(mlx_tests.MLXTestCase):
        cdfdx = mx.grad(outer)(x)
        self.assertTrue(mx.allclose(dfdx, cdfdx))

    def test_compile_capture(self):
        # Test updating captured state outside the compiled function
        state = {"y": mx.array(2)}

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state["y"]
            return x

        test_state(mx.array(1))
        # Check the state is unchanged
        self.assertEqual(state["y"], 2)

        # Check the updated state is used
        state["y"] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Capture a list
        state = [mx.array(2)]

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state[0]
            return x

        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 3)
        state[0] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Capture a tuple of a list
        state = ([mx.array(2)],)

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state[0][0]
            return x

        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 3)
        state[0][0] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Test state updated inside the compiled function
        state = {}

        @partial(mx.compile, outputs=state)
        def test_state(x):
            state["y"] = x + 3
            return mx.abs(x)

        test_state(mx.array(-1))
        self.assertEqual(state["y"].item(), 2)

        # Test that state changed inside the compiled function
        # triggers a recompile
        state = {}

        @partial(mx.compile, inputs=state, outputs=state)
        def test_state(x):
            y = state.get("y", mx.array(0))
            state["y"] = x + y
            return x + 2 * y

        test_state(mx.array(1))
        self.assertEqual(state["y"].item(), 1)
        test_state(mx.array(1))
        self.assertEqual(state["y"].item(), 2)

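A minimal sketch of the capture pattern these tests exercise (not part of this
diff; it assumes only the mx.compile inputs/outputs API shown above): arrays
reachable through inputs/outputs are treated as implicit inputs and outputs of
the compiled graph rather than baked-in constants, so in-place updates to the
container are observed without retracing.

    from functools import partial
    import mlx.core as mx

    counter = {"n": mx.array(0)}

    @partial(mx.compile, inputs=counter, outputs=counter)
    def bump(x):
        # Captured output: the new value is visible after the call
        counter["n"] = counter["n"] + 1
        return x + counter["n"]

    bump(mx.array(0))
    print(counter["n"].item())  # 1
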
    def test_compile_rng(self):
        @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
        def fun():
            return mx.random.uniform(shape=(10, 10))

        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))

    def test_compile_kwargs(self):
        @mx.compile
        def fun(x, y, z):
            return x + y + z

        x = mx.array(1)
        y = mx.array(2)
        z = mx.array(3)
        out = fun(x, y=y, z=z)
        self.assertEqual(out.item(), 6)

    def test_shapeless_compile(self):
        y = 1

        @partial(mx.compile, shapeless=True)
        def fun(x):
            return x + y

        x = mx.array([1, 2])
        self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))

        # The function is not recompiled, so the change
        # to y should not be reflected in the output
        y = 2
        x = mx.array([1, 2, 3])
        self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))

        # Type change recompiles
        x = mx.array([1.0, 2.0, 3.0])
        self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))

        # Dim change recompiles
        x = mx.array([[1, 2, 3]])
        self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))

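In short (a sketch of the behavior the test above pins down, not from this
diff): with shapeless=True a new input shape reuses the existing trace, while a
dtype or ndim change forces a retrace, at which point the current values of
captured Python constants are picked up.

    y = 1
    fn = mx.compile(lambda x: x + y, shapeless=True)
    fn(mx.array([0, 0]))      # traces once, with y == 1
    y = 2
    fn(mx.array([0, 0, 0]))   # same dtype/ndim: no retrace, still uses y == 1
    fn(mx.array([0.0, 0.0]))  # dtype change: retraces and now sees y == 2
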
    def test_shapeless_compile_with_broadcasts(self):
        x = mx.ones((2, 2))
        y = mx.array([2, 2])

        def fun(x, y):
            return x * y

        cfun = mx.compile(fun, shapeless=True)
        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
        y = mx.array([[3]])
        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))

    def test_shapeless_compile_with_reduction(self):
        # Test shapeless compile with a reduction
        z = 1

        @partial(mx.compile, shapeless=True)
        def fun(x, y):
            return x + y.sum(0, keepdims=True) + z

        x = mx.ones((2, 2), mx.int32)
        y = mx.ones((2, 2), mx.int32)
        self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4)))
        x = mx.ones((3, 3), mx.int32)
        y = mx.ones((3, 3), mx.int32)
        z = 2
        self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5)))

        x1 = mx.array([[1, 2], [3, 4], [5, 6]])
        x2 = mx.array([[1, 2]])

        def fun(x):
            return x * x.sum(-1, keepdims=True)

        cfun = mx.compile(fun, shapeless=True)
        mx.eval(cfun(x1))
        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))

    def test_compile_with_constant(self):
        # Test float
        @partial(mx.compile)
        def fun(x, y):
            return x + y

        z = fun(mx.array(1.0), 1.0)
        self.assertEqual(z.item(), 2.0)

        z = fun(mx.array(1.0), 2.0)
        self.assertEqual(z.item(), 3.0)

        z = fun(mx.array(1.0), y=1.0)
        self.assertEqual(z.item(), 2.0)

        z = fun(mx.array(1.0), y=3.0)
        self.assertEqual(z.item(), 4.0)

        # Test tuple
        @partial(mx.compile)
        def fun(x, y=(1, 2)):
            return x + y[0] + y[1]

        z = fun(mx.array(1))
        self.assertEqual(z.item(), 4)

        z = fun(mx.array(1), (2, 2))
        self.assertEqual(z.item(), 5)

        z = fun(mx.array(1), (2, 1))
        self.assertEqual(z.item(), 4)

        # Test bool
        @partial(mx.compile)
        def fun(x, y):
            if y:
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), True)
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), False)
        self.assertEqual(z.item(), 3)

        # Test string
        @partial(mx.compile)
        def fun(x, y):
            if y == "one":
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), "one")
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), "two")
        self.assertEqual(z.item(), 3)


if __name__ == "__main__":
    unittest.main()

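A short aside (a sketch under the same assumptions as the tests above, not part
of this diff): the constant tests work because non-array arguments such as
floats, tuples, bools, and strings are folded into the compile cache key, so
each new constant value selects or builds a separate trace rather than silently
reusing a stale one.

    @mx.compile
    def scale(x, factor):
        return x * factor

    scale(mx.array([1.0, 2.0]), 2.0)  # traces with factor == 2.0
    scale(mx.array([1.0, 2.0]), 3.0)  # new constant: a second trace, uses 3.0
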
@@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase):
        # Restore device
        mx.set_default_device(device)

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_device_context(self):
        default = mx.default_device()
        diff = mx.cpu if default == mx.gpu else mx.gpu
        self.assertNotEqual(default, diff)
        with mx.stream(diff):
            a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
            mx.eval(a)
            self.assertEqual(mx.default_device(), diff)
        self.assertEqual(mx.default_device(), default)

    def test_op_on_device(self):
        x = mx.array(1.0)
        y = mx.array(1.0)

@@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase):
        y = dfun_dx(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)

    def test_eval_mixed(self):
        x = mx.array(1) + 1 + 1
        y = 0
        z = "hello"
        state = [x, y, z]
        mx.eval(state)
        self.assertEqual(x.item(), 3)


if __name__ == "__main__":
    unittest.main()

python/tests/test_fast.py (new file, 158 lines)
@@ -0,0 +1,158 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
def rope_orig(x, dims, traditional, base, scale, offset):
|
||||
N = x.shape[1] + offset
|
||||
dtype = x.dtype
|
||||
half_D = dims // 2
|
||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
|
||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||
costheta, sintheta = mx.cos(theta), mx.sin(theta)
|
||||
if traditional:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
rx1 = x1 * costheta - x2 * sintheta
|
||||
rx2 = x1 * sintheta + x2 * costheta
|
||||
rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
|
||||
return mx.reshape(rx, x.shape)
|
||||
else:
|
||||
x1 = x[..., : dims // 2]
|
||||
x2 = x[..., dims // 2 : dims]
|
||||
rx1 = x1 * costheta - x2 * sintheta
|
||||
rx2 = x1 * sintheta + x2 * costheta
|
||||
if dims < x.shape[-1]:
|
||||
rx = mx.concatenate([rx1, rx2, x[..., dims:]], axis=-1)
|
||||
else:
|
||||
rx = mx.concatenate([rx1, rx2], axis=-1)
|
||||
return rx
|
||||
|
||||
|
||||
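# Note on the reference above (derived from the code, not stated in the diff):
# the rotation angles are theta[p, i] = p * scale * base**(-2 * i / dims) for
# i in range(dims // 2), with positions p = offset, ..., x.shape[1] + offset - 1.
# Each (x1, x2) pair is rotated by its theta, and mx.fast.rope is checked
# against this within the per-dtype tolerances below.
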
class TestFast(mlx_tests.MLXTestCase):
    def test_rope(self):
        T = 4

        # Defaults: dims, dtype, base, scale, offset, traditional
        defaults = (8, mx.float32, 10000.0, 1.0, 0, False)

        # Per-dtype absolute tolerance
        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

        # Test cases:
        dtypes = [mx.float32, mx.float16, mx.bfloat16]
        bases = [10000.0, 1000000.0]
        scales = [1.0, 2.0]
        offsets = [0, 3]
        traditional = [True, False]

        for traditional in [True, False]:
            dims, dtype, _, scale, offset, _ = defaults
            for base in bases:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, _, base, scale, offset, _ = defaults
            for dtype in dtypes:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                ry = rope_orig(
                    x.astype(mx.float32), dims, traditional, base, scale, offset
                )
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                if dtype != mx.float32:
                    self.assertLessEqual(
                        mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
                    )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, dtype, base, scale, _, _ = defaults
            for offset in offsets:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, dtype, base, _, offset, _ = defaults
            for scale in scales:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    def test_fast_transforms(self):
        x = mx.random.uniform(shape=(2, 2, 8))

        defaults = (8, False, 10000.0, 1.0, 0)
        dims, traditional, base, scale, offset = defaults

        # VJP
        _, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, vjp_fast_out = mx.vjp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0]))

        # JVP
        _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, jvp_fast_out = mx.jvp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))

        # VMAP
        x = mx.random.uniform(shape=(2, 2, 2, 8))
        vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)
        vmap_fast_out = mx.vmap(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            )
        )(x)
        self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))


if __name__ == "__main__":
    unittest.main()

@@ -19,72 +19,73 @@ class TestFFT(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

    def test_fft(self):
        def check_mx_np(op_mx, op_np, a_np, **kwargs):
            out_np = op_np(a_np, **kwargs)
            a_mx = mx.array(a_np)
            out_mx = op_mx(a_mx, **kwargs)
            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

        with mx.stream(mx.cpu):
            r = np.random.rand(100).astype(np.float32)
            i = np.random.rand(100).astype(np.float32)
            a_np = r + 1j * i
            check_mx_np(mx.fft.fft, np.fft.fft, a_np)

            # Check with slicing and padding
            r = np.random.rand(100).astype(np.float32)
            i = np.random.rand(100).astype(np.float32)
            a_np = r + 1j * i
            check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
            check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)

            # Check different axes
            r = np.random.rand(100, 100).astype(np.float32)
            i = np.random.rand(100, 100).astype(np.float32)
            a_np = r + 1j * i
            check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
            check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)

            # Check real fft
            a_np = np.random.rand(100).astype(np.float32)
            check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
            check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
            check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)

            # Check real inverse
            r = np.random.rand(100, 100).astype(np.float32)
            i = np.random.rand(100, 100).astype(np.float32)
            a_np = r + 1j * i
            check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
            check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
            check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
            check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
            check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
            check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)

    def test_fftn(self):
        with mx.stream(mx.cpu):
            r = np.random.randn(8, 8, 8).astype(np.float32)
            i = np.random.randn(8, 8, 8).astype(np.float32)
            a = r + 1j * i
            axes = [None, (1, 2), (2, 1), (0, 2)]
            shapes = [None, (10, 5), (5, 10)]
            ops = [
                "fft2",
                "ifft2",
                "rfft2",
                "irfft2",
                "fftn",
                "ifftn",
                "rfftn",
                "irfftn",
            ]

            for op, ax, s in itertools.product(ops, axes, shapes):
                x = a
                if op in ["rfft2", "rfftn"]:
                    x = r
                self.check_mx_np(op, x, axes=ax, s=s)


if __name__ == "__main__":

@@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
    def test_save_and_load_safetensors(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)
        with self.assertRaises(Exception):
            mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})

        mx.save_safetensors(
            "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
        )
        res = mx.load("test.safetensors", return_metadata=True)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})

        for dt in self.dtypes + ["bfloat16"]:
            with self.subTest(dtype=dt):

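A note on why the first call above must raise (an assumption based on the
safetensors format, not stated in the diff): metadata is a string-to-string
mapping, so a non-string value such as 0 is rejected while string values
round-trip.

    mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": "0"})  # ok: value is a str
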
@@ -75,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase):
                    self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
                )
                save_dict = {
                    "test": (
                        mx.random.normal(shape=shape, dtype=getattr(mx, dt))
                        if dt in ["float32", "float16", "bfloat16"]
                        else mx.ones(shape, dtype=getattr(mx, dt))
                    )
                }

                with open(save_file_mlx, "wb") as f:

@@ -104,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase):
                    self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
                )
                save_dict = {
                    "test": (
                        mx.random.normal(shape=shape, dtype=getattr(mx, dt))
                        if dt in ["float32", "float16", "bfloat16"]
                        else mx.ones(shape, dtype=getattr(mx, dt))
                    )
                }

                mx.save_gguf(save_file_mlx, save_dict)

@@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase):
        expected_sum = mx.sum(expected_none)
        self.assertEqual(losses_sum, expected_sum)

        # With weights, no label smoothing
        weights = mx.array([1.0, 2.0, 1.0, 2.0])
        expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944])
        loss = nn.losses.binary_cross_entropy(
            logits, targets, weights=weights, reduction="none"
        )
        self.assertTrue(mx.allclose(loss, expected))

    def _test_probs_as_inputs():
        probs = mx.array([0.5, 0.6, 0.7, 0.8])
        targets = mx.array([0, 0, 1, 1])

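A reading of the expected values above (an inference, not stated in the diff):
the weights scale the per-element loss, so with weights [1.0, 2.0, 1.0, 2.0]
each expected entry is weight_i * bce_i, where bce_i is the unweighted binary
cross-entropy of (logits_i, targets_i).
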
@@ -71,7 +71,7 @@ class TestBase(mlx_tests.MLXTestCase):
    def test_save_safetensors_weights(self):
        def make_model():
            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU())

        m = make_model()
        tdir = tempfile.TemporaryDirectory()

@@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase):
            ]
        )

    def test_module_state(self):
        m = nn.Linear(10, 1)
        m.state["hello"] = "world"
        self.assertEqual(m.state["hello"], "world")


class TestLayers(mlx_tests.MLXTestCase):
    def test_identity(self):

@@ -900,6 +905,347 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertTrue(y.shape, x.shape)
        self.assertTrue(y.dtype, mx.float16)

    def test_pooling(self):
        # Test 1d pooling
        x = mx.array(
            [
                [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
                [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]],
            ]
        )
        expected_max_pool_output_no_padding_stride_1 = [
            [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
            [[15.0, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
        ]
        expected_max_pool_output_no_padding_stride_2 = [
            [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]],
            [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]],
        ]
        expected_max_pool_output_padding_1_stride_2 = [
            [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
            [[12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
        ]
        expected_max_pool_output_padding_1_stride_2_kernel_3 = [
            [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]],
            [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]],
        ]
        expected_avg_pool_output_no_padding_stride_1 = [
            [
                [1.5000, 2.5000, 3.5000],
                [4.5000, 5.5000, 6.5000],
                [7.5000, 8.5000, 9.5000],
            ],
            [
                [13.5000, 14.5000, 15.5000],
                [16.5000, 17.5000, 18.5000],
                [19.5000, 20.5000, 21.5000],
            ],
        ]
        expected_avg_pool_output_no_padding_stride_2 = [
            [[1.5000, 2.5000, 3.5000], [7.5000, 8.5000, 9.5000]],
            [[13.5000, 14.5000, 15.5000], [19.5000, 20.5000, 21.5000]],
        ]
        expected_avg_pool_output_padding_1_stride_2 = [
            [
                [0.0000, 0.5000, 1.0000],
                [4.5000, 5.5000, 6.5000],
                [4.5000, 5.0000, 5.5000],
            ],
            [
                [6.0000, 6.5000, 7.0000],
                [16.5000, 17.5000, 18.5000],
                [10.5000, 11.0000, 11.5000],
            ],
        ]
        expected_avg_pool_output_padding_1_kernel_3 = [
            [[1, 1.66667, 2.33333], [6, 7, 8]],
            [[9, 9.66667, 10.3333], [18, 19, 20]],
        ]
        self.assertTrue(
            np.array_equal(
                nn.MaxPool1d(kernel_size=2, stride=1, padding=0)(x),
                expected_max_pool_output_no_padding_stride_1,
            )
        )
        self.assertTrue(
            np.array_equal(
                nn.MaxPool1d(kernel_size=2, stride=2, padding=0)(x),
                expected_max_pool_output_no_padding_stride_2,
            )
        )
        self.assertTrue(
            np.array_equal(
                nn.MaxPool1d(kernel_size=2, stride=2, padding=1)(x),
                expected_max_pool_output_padding_1_stride_2,
            )
        )
        self.assertTrue(
            np.array_equal(
                nn.MaxPool1d(kernel_size=3, stride=2, padding=1)(x),
                expected_max_pool_output_padding_1_stride_2_kernel_3,
            )
        )
        self.assertTrue(
            np.allclose(
                nn.AvgPool1d(kernel_size=2, stride=1, padding=0)(x),
                expected_avg_pool_output_no_padding_stride_1,
            )
        )
        self.assertTrue(
            np.allclose(
                nn.AvgPool1d(kernel_size=2, stride=2, padding=0)(x),
                expected_avg_pool_output_no_padding_stride_2,
            )
        )
        self.assertTrue(
            np.allclose(
                nn.AvgPool1d(kernel_size=2, stride=2, padding=1)(x),
                expected_avg_pool_output_padding_1_stride_2,
            )
        )
        self.assertTrue(
            np.allclose(
                nn.AvgPool1d(kernel_size=3, stride=2, padding=1)(x),
                expected_avg_pool_output_padding_1_kernel_3,
            )
        )
        # Test 2d pooling
        x = mx.array(
            [
                [
                    [[0, 16], [1, 17], [2, 18], [3, 19]],
                    [[4, 20], [5, 21], [6, 22], [7, 23]],
                    [[8, 24], [9, 25], [10, 26], [11, 27]],
                    [[12, 28], [13, 29], [14, 30], [15, 31]],
                ]
            ]
        )
        expected_max_pool_output_no_padding_stride_1 = [
            [
                [[5, 21], [6, 22], [7, 23]],
                [[9, 25], [10, 26], [11, 27]],
                [[13, 29], [14, 30], [15, 31]],
            ]
        ]
        expected_max_pool_output_no_padding_stride_2 = [
            [[[5, 21], [7, 23]], [[13, 29], [15, 31]]]
        ]
        expected_max_pool_output_padding_1 = [
            [
                [[0, 16], [2, 18], [3, 19]],
                [[8, 24], [10, 26], [11, 27]],
                [[12, 28], [14, 30], [15, 31]],
            ]
        ]
        expected_mean_pool_output_no_padding_stride_1 = [
            [
                [[2.5000, 18.5000], [3.5000, 19.5000], [4.5000, 20.5000]],
                [[6.5000, 22.5000], [7.5000, 23.5000], [8.5000, 24.5000]],
                [[10.5000, 26.5000], [11.5000, 27.5000], [12.5000, 28.5000]],
            ]
        ]
        expected_mean_pool_output_no_padding_stride_2 = [
            [
                [[2.5000, 18.5000], [4.5000, 20.5000]],
                [[10.5000, 26.5000], [12.5000, 28.5000]],
            ]
        ]
        expected_mean_pool_output_padding_1 = [
            [
                [[0.0000, 4.0000], [0.7500, 8.7500], [0.7500, 4.7500]],
                [[3.0000, 11.0000], [7.5000, 23.5000], [4.5000, 12.5000]],
                [[3.0000, 7.0000], [6.7500, 14.7500], [3.7500, 7.7500]],
            ]
        ]
        self.assertTrue(
            np.array_equal(
                nn.MaxPool2d(kernel_size=2, stride=1, padding=0)(x),
                expected_max_pool_output_no_padding_stride_1,
            )
        )
        self.assertTrue(
            np.array_equal(
                nn.MaxPool2d(kernel_size=2, stride=2, padding=0)(x),
                expected_max_pool_output_no_padding_stride_2,
            )
        )
        self.assertTrue(
            np.array_equal(
                nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(x),
                expected_max_pool_output_padding_1,
            )
        )
        # Average pooling
        self.assertTrue(
            np.allclose(
                nn.AvgPool2d(kernel_size=2, stride=1, padding=0)(x),
                expected_mean_pool_output_no_padding_stride_1,
            )
        )
        self.assertTrue(
            np.array_equal(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0)(x),
                expected_mean_pool_output_no_padding_stride_2,
            )
        )
        self.assertTrue(
            np.array_equal(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=1)(x),
                expected_mean_pool_output_padding_1,
            )
        )
        # Test multiple batches
        x = mx.array(
            [
                [
                    [[0, 1], [2, 3], [4, 5], [6, 7]],
                    [[8, 9], [10, 11], [12, 13], [14, 15]],
                    [[16, 17], [18, 19], [20, 21], [22, 23]],
                    [[24, 25], [26, 27], [28, 29], [30, 31]],
                ],
                [
                    [[32, 33], [34, 35], [36, 37], [38, 39]],
                    [[40, 41], [42, 43], [44, 45], [46, 47]],
                    [[48, 49], [50, 51], [52, 53], [54, 55]],
                    [[56, 57], [58, 59], [60, 61], [62, 63]],
                ],
            ]
        )
        expected_max_pool_output = [
            [[[10.0, 11.0], [14.0, 15.0]], [[26.0, 27.0], [30.0, 31.0]]],
            [[[42.0, 43.0], [46.0, 47.0]], [[58.0, 59.0], [62.0, 63.0]]],
        ]
        expected_avg_pool_output = [
            [[[2.22222, 2.66667], [5.33333, 6]], [[11.3333, 12], [20, 21]]],
            [[[16.4444, 16.8889], [26.6667, 27.3333]], [[32.6667, 33.3333], [52, 53]]],
        ]
        self.assertTrue(
            np.array_equal(
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x),
                expected_max_pool_output,
            )
        )
        self.assertTrue(
            np.allclose(
                nn.AvgPool2d(kernel_size=3, stride=2, padding=1)(x),
                expected_avg_pool_output,
            )
        )
        # Test irregular kernel (2, 4), stride (3, 1) and padding (1, 2)
        x = mx.array(
            [
                [
                    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
                    [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]],
                    [[24, 25, 26], [27, 28, 29], [30, 31, 32], [33, 34, 35]],
                    [[36, 37, 38], [39, 40, 41], [42, 43, 44], [45, 46, 47]],
                ],
                [
                    [[48, 49, 50], [51, 52, 53], [54, 55, 56], [57, 58, 59]],
                    [[60, 61, 62], [63, 64, 65], [66, 67, 68], [69, 70, 71]],
                    [[72, 73, 74], [75, 76, 77], [78, 79, 80], [81, 82, 83]],
                    [[84, 85, 86], [87, 88, 89], [90, 91, 92], [93, 94, 95]],
                ],
            ]
        )
        expected_irregular_max_pool_output = [
            [
                [
                    [3.0, 4.0, 5.0],
                    [6.0, 7.0, 8.0],
                    [9.0, 10.0, 11.0],
                    [9.0, 10.0, 11.0],
                    [9.0, 10.0, 11.0],
                ],
                [
                    [39.0, 40.0, 41.0],
                    [42.0, 43.0, 44.0],
                    [45.0, 46.0, 47.0],
                    [45.0, 46.0, 47.0],
                    [45.0, 46.0, 47.0],
                ],
            ],
            [
                [
                    [51.0, 52.0, 53.0],
                    [54.0, 55.0, 56.0],
                    [57.0, 58.0, 59.0],
                    [57.0, 58.0, 59.0],
                    [57.0, 58.0, 59.0],
                ],
                [
                    [87.0, 88.0, 89.0],
                    [90.0, 91.0, 92.0],
                    [93.0, 94.0, 95.0],
                    [93.0, 94.0, 95.0],
                    [93.0, 94.0, 95.0],
                ],
            ],
        ]
        expected_irregular_average_pool_output = [
            [
                [
                    [0.3750, 0.6250, 0.8750],
                    [1.1250, 1.5000, 1.8750],
                    [2.2500, 2.7500, 3.2500],
                    [2.2500, 2.6250, 3.0000],
                    [1.8750, 2.1250, 2.3750],
                ],
                [
                    [15.7500, 16.2500, 16.7500],
                    [24.7500, 25.5000, 26.2500],
                    [34.5000, 35.5000, 36.5000],
                    [27.0000, 27.7500, 28.5000],
                    [18.7500, 19.2500, 19.7500],
                ],
            ],
            [
                [
                    [12.3750, 12.6250, 12.8750],
                    [19.1250, 19.5000, 19.8750],
                    [26.2500, 26.7500, 27.2500],
                    [20.2500, 20.6250, 21.0000],
                    [13.8750, 14.1250, 14.3750],
                ],
                [
                    [39.7500, 40.2500, 40.7500],
                    [60.7500, 61.5000, 62.2500],
                    [82.5000, 83.5000, 84.5000],
                    [63.0000, 63.7500, 64.5000],
                    [42.7500, 43.2500, 43.7500],
                ],
            ],
        ]
        self.assertTrue(
            np.array_equal(
                nn.MaxPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x),
                expected_irregular_max_pool_output,
            )
        )
        self.assertTrue(
            np.allclose(
                nn.AvgPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x),
                expected_irregular_average_pool_output,
            )
        )
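        # Added note (standard convolution arithmetic, not from the diff): per
        # spatial axis the pooled size is
        # floor((n + 2 * padding - kernel_size) / stride) + 1; for the irregular
        # case above, rows: (4 + 2 * 1 - 2) // 3 + 1 == 2 and
        # cols: (4 + 2 * 2 - 4) // 1 + 1 == 5, matching the expected shapes.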
        # Test repr
        self.assertEqual(
            str(nn.MaxPool1d(kernel_size=3, padding=2)),
            "MaxPool1d(kernel_size=(3,), stride=(3,), padding=(2,))",
        )
        self.assertEqual(
            str(nn.AvgPool1d(kernel_size=2, stride=3)),
            "AvgPool1d(kernel_size=(2,), stride=(3,), padding=(0,))",
        )
        self.assertEqual(
            str(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
            "MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))",
        )
        self.assertEqual(
            str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
            "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
        )


if __name__ == "__main__":
    unittest.main()

@@ -1,6 +1,7 @@

# Copyright © 2023-2024 Apple Inc.

import math
import os
import unittest
from itertools import permutations

@@ -274,6 +275,20 @@ class TestOps(mlx_tests.MLXTestCase):
            self.assertEqual(z.dtype, dt)
            self.assertEqual(z.item(), 1)

            z = -1 % x
            self.assertEqual(z.dtype, dt)
            self.assertEqual(z.item(), 1)

            z = -1 % -x
            self.assertEqual(z.dtype, dt)
            self.assertEqual(z.item(), -1)

            x = mx.arange(10).astype(dt) - 5
            y = x % 5
            z = x % -5
            self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
            self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])

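A side note on the expected values (standard Python semantics, not specific to
this diff): the arrays follow Python's floored modulo, where the result takes
the sign of the divisor.

    -1 % 2    # 1
    -1 % -2   # -1
    -4 % 5    # 1
    -4 % -5   # -4
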
    def test_comparisons(self):
        a = mx.array([0.0, 1.0, 5.0])
        b = mx.array([-1.0, 2.0, 5.0])

@@ -1012,6 +1027,9 @@ class TestOps(mlx_tests.MLXTestCase):
        self.assertEqual(y.tolist(), [[3, 4]])
        self.assertEqual(z.tolist(), [[5, 6]])

        with self.assertRaises(ValueError):
            mx.split(a, 3, axis=2)

        a = mx.arange(8)
        x, y, z = mx.split(a, [1, 5])
        self.assertEqual(x.tolist(), [0])

@@ -1318,9 +1336,7 @@ class TestOps(mlx_tests.MLXTestCase):
        for d in dims:
            anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)
            for n_bsx in range(d):
                bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx)
                for _ in range(trial_mul * d):
                    amlx = mx.array(anp)
                    bmlx = mx.array(bnp)

@@ -1371,6 +1387,11 @@ class TestOps(mlx_tests.MLXTestCase):
        self.assertTrue((a[:-1] < 1e-9).all())
        self.assertEqual(a[-1], 1)

        # Sliced inputs
        y = mx.random.uniform(shape=(8, 4))
        out = mx.softmax(y[:, 0:2], axis=-1)
        self.assertAlmostEqual(out.sum().item(), 8.0)
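        # Added note: each of the 8 rows of the sliced softmax sums to 1, hence
        # the expected total of 8.0; the slice presumably exercises
        # non-contiguous input handling.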
    def test_concatenate(self):
        a_npy = np.random.randn(32, 32, 32)
        b_npy = np.random.randn(32, 32, 32)

@@ -1566,6 +1587,10 @@ class TestOps(mlx_tests.MLXTestCase):
            d_np = np.take(b_mx, np.arange(kth), axis=axis)
            self.assertTrue(np.all(d_np <= c_mx))

    @unittest.skipIf(
        os.getenv("LOW_MEMORY", None) is not None,
        "This test requires a lot of memory",
    )
    def test_large_binary(self):
        a = mx.ones([1000, 2147484], mx.int8)
        b = mx.ones([2147484], mx.int8)

@@ -1677,6 +1702,8 @@ class TestOps(mlx_tests.MLXTestCase):
    def test_repeat(self):
        # Set up data for the tests
        data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
        # Test repeating 0 times
        self.assertCmpNumpy([data, 0], mx.repeat, np.repeat)
        # Test repeating along axis 0
        self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
        # Test repeating along axis 1

@@ -1856,6 +1883,96 @@ class TestOps(mlx_tests.MLXTestCase):
        expected = mx.array(np.diag(x, k=-1))
        self.assertTrue(mx.array_equal(result, expected))

    def test_atleast_1d(self):
        def compare_nested_lists(x, y):
            if isinstance(x, list) and isinstance(y, list):
                if len(x) != len(y):
                    return False
                for i in range(len(x)):
                    if not compare_nested_lists(x[i], y[i]):
                        return False
                return True
            else:
                return x == y

        # Test inputs of various nesting depths
        arrays = [
            [1],
            [1, 2, 3],
            [1, 2, 3, 4],
            [[1], [2], [3]],
            [[1, 2], [3, 4]],
            [[1, 2, 3], [4, 5, 6]],
            [[[[1]], [[2]], [[3]]]],
        ]

        for array in arrays:
            mx_res = mx.atleast_1d(mx.array(array))
            np_res = np.atleast_1d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)

    def test_atleast_2d(self):
        def compare_nested_lists(x, y):
            if isinstance(x, list) and isinstance(y, list):
                if len(x) != len(y):
                    return False
                for i in range(len(x)):
                    if not compare_nested_lists(x[i], y[i]):
                        return False
                return True
            else:
                return x == y

        # Test inputs of various nesting depths
        arrays = [
            [1],
            [1, 2, 3],
            [1, 2, 3, 4],
            [[1], [2], [3]],
            [[1, 2], [3, 4]],
            [[1, 2, 3], [4, 5, 6]],
            [[[[1]], [[2]], [[3]]]],
        ]

        for array in arrays:
            mx_res = mx.atleast_2d(mx.array(array))
            np_res = np.atleast_2d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)

    def test_atleast_3d(self):
        def compare_nested_lists(x, y):
            if isinstance(x, list) and isinstance(y, list):
                if len(x) != len(y):
                    return False
                for i in range(len(x)):
                    if not compare_nested_lists(x[i], y[i]):
                        return False
                return True
            else:
                return x == y

        # Test inputs of various nesting depths
        arrays = [
            [1],
            [1, 2, 3],
            [1, 2, 3, 4],
            [[1], [2], [3]],
            [[1, 2], [3, 4]],
            [[1, 2, 3], [4, 5, 6]],
            [[[[1]], [[2]], [[3]]]],
        ]

        for array in arrays:
            mx_res = mx.atleast_3d(mx.array(array))
            np_res = np.atleast_3d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)


if __name__ == "__main__":
    unittest.main()

@@ -1,50 +1,215 @@

# Copyright © 2023 Apple Inc.

import inspect
import math
import unittest
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
from mlx.utils import tree_flatten, tree_map


def get_all_optimizers():
    classes = dict()
    for name, obj in inspect.getmembers(opt):
        if (
            inspect.isclass(obj)
            and issubclass(obj, opt.Optimizer)
            and obj != opt.Optimizer
        ):
            classes[name] = obj
    return classes


def tree_equal(fn, *args):
    return all(v for _, v in tree_flatten(tree_map(fn, *args)))


optimizers_dict = get_all_optimizers()


class TestOptimizers(mlx_tests.MLXTestCase):
    def test_optimizer_state(self):
        optim = opt.SGD(0.1)
        optim.state["hello"] = "world"
        self.assertEqual(optim.state["hello"], "world")

        optim.state = {0: 1}
        self.assertEqual(optim.state, {0: 1})

    def test_optimizers(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        for optim_class in optimizers_dict.values():
            optim = optim_class(0.1)
            update = optim.apply_gradients(grads, params)
            mx.eval(update)
            equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update)
            all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
            self.assertTrue(all_equal)

    def test_types_conserved(self):
        params = {"w": mx.ones((5, 5), mx.float16)}
        grads = tree_map(lambda x: mx.ones_like(x), params)
        for optim_class in optimizers_dict.values():
            optim = optim_class(0.1)
            update = optim.apply_gradients(grads, params)
            self.assertEqual(update["w"].dtype, mx.float16)

    def test_sgd(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

        # Implicit init
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        optim.apply_gradients(grads, params)
        self.assertTrue(
            tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state)
        )

    def test_rmsprop(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.RMSprop(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

        # Implicit init
        alpha = 0.99
        optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
        optim.apply_gradients(grads, params)
        self.assertTrue(
            tree_equal(
                lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state
            )
        )

    def test_adagrad(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Adagrad(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adadelta(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.AdaDelta(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adam(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:
            optim = optimizer(learning_rate=1e-2)
            optim.init(params)
            self.assertTrue(
                tree_equal(
                    lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                    params,
                    optim.state,
                )
            )
            self.assertTrue(
                tree_equal(
                    lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
                    params,
                    optim.state,
                )
            )

    def test_lion(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Lion(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adafactor(self):
        x = mx.zeros((5, 5))
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_gradients(grad, x)
            self.assertEqual(xp.dtype, x.dtype)
            self.assertEqual(xp.shape, x.shape)

@@ -52,11 +217,129 @@ class TestOptimizers(mlx_tests.MLXTestCase):
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_gradients(grad, x)
            self.assertEqual(xp.dtype, x.dtype)
            self.assertEqual(xp.shape, x.shape)
        self.assertEqual(optimizer.state["step"], 2)

    def test_compiled_optimizer(self):
        model = nn.Linear(10, 10)
        x = mx.random.uniform(shape=(2, 10))
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

        orig_params = model.parameters()

        def loss(model, x):
            return model(x).sum()

        # Uncompiled version
        def step(x):
            _, grad = nn.value_and_grad(model, loss)(model, x)
            optim.update(model, grad)

        step(x)
        uncompiled_params = model.parameters()

        # Pure version
        def loss(params, x):
            model.update(params)
            return model(x).sum()

        model.update(orig_params)
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

        @mx.compile
        def step(params, opt_state, x):
            grad = mx.grad(loss)(params, x)
            optim.state = opt_state
            params = optim.apply_gradients(grad, params)
            return params, optim.state

        optim.init(model.parameters())
        pure_params, _ = step(model.parameters(), optim.state, x)
        self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"]))
        self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"]))

        # Impure version
        def loss(model, x):
            return model(x).sum()

        model.update(orig_params)
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        state = [model.state, optim.state]

        @partial(mx.compile, inputs=state, outputs=state)
        def step(x):
            _, grad = nn.value_and_grad(model, loss)(model, x)
            optim.update(model, grad)

        step(x)
        impure_params = model.parameters()
        self.assertTrue(
            mx.allclose(impure_params["weight"], uncompiled_params["weight"])
        )
        self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"]))

    def test_update_lr_compiled(self):
        params = {"w": mx.ones((5, 5))}
        grads = tree_map(lambda x: mx.ones_like(x), params)
        optim = opt.SGD(-1.0)

        @partial(mx.compile, inputs=optim.state)
        def update(grads):
            return optim.apply_gradients(grads, params)

        result = update(grads)
        self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0)))
        optim.learning_rate = -2.0
        result = update(grads)
        self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))


class TestSchedulers(unittest.TestCase):
    def test_decay_lr(self):
        for optim_class in optimizers_dict.values():
            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
            optimizer = optim_class(learning_rate=lr_schedule)

            params = {"w": mx.ones((5, 5))}
            grads = tree_map(lambda x: mx.ones_like(x), params)

            for it in range(10):
                expected_lr = 0.1 * (0.9**it)
                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
                return optimizer.apply_gradients(grads, params)

    def test_step_decay(self):
        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
        lr = lr_schedule(2500)
        expected_lr = 0.1 * (0.9**2)
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_exponential_decay(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.99)
        lr = lr_schedule(10)
        expected_lr = 0.1 * (0.99**10)
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_cosine_decay(self):
        lr_schedule = opt.cosine_decay(0.1, 10)
        lr = lr_schedule(4)
        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

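For reference, the schedules above reduce to the following closed forms (read
off the expected values in the tests; argument names mirror the calls):

    step_decay(init, rate, step_size):  lr(t) = init * rate ** (t // step_size)
    exponential_decay(init, rate):      lr(t) = init * rate ** t
    cosine_decay(init, T):              lr(t) = init * 0.5 * (1 + cos(pi * t / T))
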
    def test_compile_with_schedule(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.9)
        optimizer = opt.SGD(learning_rate=lr_schedule)

        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
        def update():
            optimizer.update({}, {})

        for step in range(5):
            update()
            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())


if __name__ == "__main__":
    unittest.main()

@@ -165,6 +165,70 @@ class TestQuantized(mlx_tests.MLXTestCase):
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_non_multiples(self):
        w = mx.random.normal(shape=(33, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 33))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 33))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Smaller than 8
        w = mx.random.normal(shape=(3, 256))
        w_q, scales, biases = mx.quantize(w)
        w_hat = mx.dequantize(w_q, scales, biases)

        # Test qmv
        x = mx.random.normal(shape=(1, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm_t
        x = mx.random.normal(shape=(10, 256))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
        y_hat = x @ w_hat.T
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qvm
        x = mx.random.normal(shape=(1, 3))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

        # Test qmm
        x = mx.random.normal(shape=(10, 3))
        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

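A note on the checks above (read directly from the assertions): with
transpose=True, quantized_matmul(x, w_q, scales, biases) is compared against
x @ dequantize(w_q, scales, biases).T, and with transpose=False against
x @ dequantize(w_q, scales, biases). The 33-row and 3-row weights presumably
exercise sizes that are not multiples of the kernel's blocking, hence the
test name.
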
if __name__ == "__main__":
    unittest.main()

@@ -80,6 +80,20 @@ class TestRandom(mlx_tests.MLXTestCase):
        a = mx.random.normal(dtype=t)
        self.assertEqual(a.dtype, t)

        # Generate with a given mean and standard deviation
        loc = 1.0
        scale = 2.0

        a = mx.random.normal(shape=(3, 2), loc=loc, scale=scale, key=key)
        b = scale * mx.random.normal(shape=(3, 2), key=key) + loc
        self.assertTrue(mx.allclose(a, b))

        a = mx.random.normal(
            shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key
        )
        b = scale * mx.random.normal(shape=(3, 2), dtype=mx.float16, key=key) + loc
        self.assertTrue(mx.allclose(a, b))

        self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
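        # Added note: the two checks above pin down the parameterization
        # normal(loc, scale) == scale * normal() + loc under the same key, for
        # both the default dtype and float16.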
    def test_randint(self):