mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Merge branch 'ml-explore:main' into main
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import operator
|
||||
import pickle
|
||||
import unittest
|
||||
import weakref
|
||||
from itertools import permutations
|
||||
@@ -1440,6 +1441,15 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
b @= a
|
||||
self.assertTrue(mx.array_equal(a, b))
|
||||
|
||||
def test_load_from_pickled_np(self):
|
||||
a = np.array([1, 2, 3], dtype=np.int32)
|
||||
b = pickle.loads(pickle.dumps(a))
|
||||
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
||||
|
||||
a = np.array([1.0, 2.0, 3.0], dtype=np.float16)
|
||||
b = pickle.loads(pickle.dumps(a))
|
||||
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
_, vjps = mx.vjp(func, (arr,), (cotan,))
|
||||
self.assertEqual(vjps[0].item(), 8.0)
|
||||
|
||||
def test_power_grad(self):
|
||||
def fun(x, y):
|
||||
res = x - y
|
||||
return res**x
|
||||
|
||||
grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
|
||||
self.assertEqual(grad.item(), 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -539,6 +539,72 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
z = fun(mx.array(1), "two")
|
||||
self.assertEqual(z.item(), 3)
|
||||
|
||||
# Test nested constant
|
||||
@partial(mx.compile)
|
||||
def fun(x, y):
|
||||
if y[0][0] == 1:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
z = fun(mx.array(1), [[1]])
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
z = fun(mx.array(1), [[0]])
|
||||
self.assertEqual(z.item(), 3)
|
||||
|
||||
@partial(mx.compile)
|
||||
def fun(x, a, b):
|
||||
for ai in a:
|
||||
for bi in b:
|
||||
x = bi * x + ai
|
||||
return x
|
||||
|
||||
z = fun(mx.array(1), [1, 1], [2])
|
||||
self.assertEqual(z.item(), 7)
|
||||
|
||||
z = fun(mx.array(1), [1], [1, 2])
|
||||
self.assertEqual(z.item(), 5)
|
||||
|
||||
counter = [0]
|
||||
|
||||
@partial(mx.compile)
|
||||
def fun(x, y):
|
||||
counter[0] += 1
|
||||
return x + y
|
||||
|
||||
z = fun(mx.array(1), 1)
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
z = fun(1, mx.array(1))
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
self.assertEqual(counter[0], 2)
|
||||
|
||||
def test_compile_inf(self):
|
||||
|
||||
@mx.compile
|
||||
def fun(x):
|
||||
return mx.isinf(x + 2)
|
||||
|
||||
out = fun(mx.array([0.0]))
|
||||
self.assertEqual(out.item(), False)
|
||||
|
||||
def test_unsupported_input_types(self):
|
||||
|
||||
class MyClass:
|
||||
value = 1
|
||||
|
||||
@mx.compile
|
||||
def fun(x, y):
|
||||
return x + y.value
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
out = fun(mx.array(0.0), MyClass())
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
out = fun(mx.array(0.0), y=MyClass())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import unittest
|
||||
@@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
_, outs_mx = mx.vjp(
|
||||
f,
|
||||
[
|
||||
in_mx,
|
||||
wt_mx,
|
||||
],
|
||||
[
|
||||
ct_mx,
|
||||
],
|
||||
[in_mx, wt_mx],
|
||||
[ct_mx],
|
||||
)
|
||||
pt_grad_in = F.grad.conv1d_input(
|
||||
in_pt.shape,
|
||||
@@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||
for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
|
||||
for idim, kdim, stride, padding, dilation in (
|
||||
((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
|
||||
((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
|
||||
((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
|
||||
((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
|
||||
):
|
||||
run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||
run_conv2D_grad(
|
||||
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
||||
)
|
||||
|
||||
def __conv_general_test(
|
||||
self,
|
||||
in_shape,
|
||||
wt_shape,
|
||||
stride=1,
|
||||
padding=0,
|
||||
kernel_dilation=1,
|
||||
input_dilation=1,
|
||||
groups=1,
|
||||
flip=False,
|
||||
np_dtype=np.float32,
|
||||
atol=1e-5,
|
||||
):
|
||||
|
||||
with self.subTest(
|
||||
in_shape=in_shape,
|
||||
wt_shape=wt_shape,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
kernel_dilation=kernel_dilation,
|
||||
input_dilation=input_dilation,
|
||||
groups=groups,
|
||||
flip=flip,
|
||||
np_dtype=np_dtype,
|
||||
):
|
||||
|
||||
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
|
||||
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
|
||||
in_pt, wt_pt = map(
|
||||
lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
|
||||
(in_np, wt_np),
|
||||
)
|
||||
|
||||
out_mx = mx.conv_general(
|
||||
in_mx,
|
||||
wt_mx,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
kernel_dilation=kernel_dilation,
|
||||
input_dilation=input_dilation,
|
||||
groups=groups,
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
def conv_general_pt(
|
||||
inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
|
||||
):
|
||||
|
||||
C = inp.size()[1]
|
||||
ndim = inp.ndim - 2
|
||||
map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
|
||||
|
||||
stride, padding, kernel_dilation, input_dilation = map(
|
||||
map_ints, (stride, padding, kernel_dilation, input_dilation)
|
||||
)
|
||||
|
||||
torch_convt_list = (
|
||||
F.conv_transpose1d,
|
||||
F.conv_transpose2d,
|
||||
F.conv_transpose3d,
|
||||
)
|
||||
torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)
|
||||
|
||||
conv_f = torch_conv_list[ndim - 1]
|
||||
convt_f = torch_convt_list[ndim - 1]
|
||||
|
||||
if flip:
|
||||
wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))
|
||||
|
||||
if not np.all(input_dilation == 1):
|
||||
ones = torch.ones(
|
||||
[C]
|
||||
+ [
|
||||
1,
|
||||
]
|
||||
* (ndim + 1)
|
||||
).to(inp.dtype)
|
||||
inp = convt_f(inp, ones, stride=input_dilation, groups=C)
|
||||
|
||||
return conv_f(
|
||||
inp,
|
||||
wt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=kernel_dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
out_pt = conv_general_pt(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
kernel_dilation=kernel_dilation,
|
||||
input_dilation=input_dilation,
|
||||
groups=groups,
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
out_pt = np.moveaxis(out_pt.numpy(), 1, -1)
|
||||
|
||||
self.assertEqual(out_mx.shape, out_pt.shape)
|
||||
self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_general(self):
|
||||
in_shape = (2, 32, 32, 16)
|
||||
wt_shape = (32, 5, 5, 16)
|
||||
stride = (1, 1)
|
||||
padding = (2, 2)
|
||||
kernel_dilation = (2, 3)
|
||||
input_dilation = (1, 1)
|
||||
flip = False
|
||||
|
||||
self.__conv_general_test(
|
||||
in_shape,
|
||||
wt_shape,
|
||||
stride,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
in_shape = (2, 32, 32, 16)
|
||||
wt_shape = (32, 5, 10, 16)
|
||||
stride = (2, 3)
|
||||
padding = (0, 0)
|
||||
kernel_dilation = (3, 2)
|
||||
input_dilation = (2, 4)
|
||||
flip = False
|
||||
|
||||
self.__conv_general_test(
|
||||
in_shape,
|
||||
wt_shape,
|
||||
stride,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
in_shape = (2, 32, 32, 16)
|
||||
wt_shape = (32, 5, 10, 16)
|
||||
stride = (2, 2)
|
||||
padding = (3, 2)
|
||||
kernel_dilation = (3, 2)
|
||||
input_dilation = (2, 4)
|
||||
flip = False
|
||||
|
||||
self.__conv_general_test(
|
||||
in_shape,
|
||||
wt_shape,
|
||||
stride,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
in_shape = (2, 32, 32, 16)
|
||||
wt_shape = (32, 5, 10, 16)
|
||||
stride = (2, 3)
|
||||
padding = (3, 2)
|
||||
kernel_dilation = (3, 2)
|
||||
input_dilation = (2, 5)
|
||||
flip = False
|
||||
|
||||
self.__conv_general_test(
|
||||
in_shape,
|
||||
wt_shape,
|
||||
stride,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
in_shape = (2, 32, 32, 16)
|
||||
wt_shape = (32, 5, 5, 16)
|
||||
stride = (2, 3)
|
||||
padding = (0, 0)
|
||||
kernel_dilation = (3, 1)
|
||||
input_dilation = (2, 5)
|
||||
flip = True
|
||||
|
||||
self.__conv_general_test(
|
||||
in_shape,
|
||||
wt_shape,
|
||||
stride,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -66,13 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
def test_save_and_load_safetensors(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
test_file = os.path.join(self.test_dir, "test.safetensors")
|
||||
with self.assertRaises(Exception):
|
||||
mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||
|
||||
mx.save_safetensors(
|
||||
"test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
|
||||
test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
|
||||
)
|
||||
res = mx.load("test.safetensors", return_metadata=True)
|
||||
res = mx.load(test_file, return_metadata=True)
|
||||
self.assertEqual(len(res), 2)
|
||||
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
||||
|
||||
|
||||
45
python/tests/test_metal.py
Normal file
45
python/tests/test_metal.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMetal(mlx_tests.MLXTestCase):
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_memory_info(self):
|
||||
old_limit = mx.metal.set_cache_limit(0)
|
||||
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
del a
|
||||
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
||||
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
|
||||
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
|
||||
|
||||
old_limit = mx.metal.set_memory_limit(10)
|
||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
|
||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
|
||||
|
||||
# Query active and peak memory
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
active_mem = mx.metal.get_active_memory()
|
||||
self.assertTrue(active_mem >= 4096 * 4)
|
||||
|
||||
b = mx.zeros((4096,))
|
||||
mx.eval(b)
|
||||
del b
|
||||
|
||||
new_active_mem = mx.metal.get_active_memory()
|
||||
self.assertEqual(new_active_mem, active_mem)
|
||||
peak_mem = mx.metal.get_peak_memory()
|
||||
self.assertTrue(peak_mem >= 4096 * 8)
|
||||
cache_mem = mx.metal.get_cache_memory()
|
||||
self.assertTrue(cache_mem >= 4096 * 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
@@ -8,7 +8,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
|
||||
|
||||
class TestBase(mlx_tests.MLXTestCase):
|
||||
@@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
y_hat1 = nn.gelu_approx(x)
|
||||
y_hat2 = nn.gelu_fast_approx(x)
|
||||
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
|
||||
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
|
||||
self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
|
||||
|
||||
def test_sin_pe(self):
|
||||
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
|
||||
@@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float16)
|
||||
|
||||
def test_upsample(self):
|
||||
b, h, w, c = 1, 2, 2, 1
|
||||
scale_factor = 2
|
||||
upsample_nearest = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="nearest", align_corners=True
|
||||
)
|
||||
upsample_bilinear = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="linear", align_corners=True
|
||||
)
|
||||
upsample_nearest = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="nearest", align_corners=True
|
||||
)
|
||||
upsample_bilinear_no_align_corners = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="linear", align_corners=False
|
||||
)
|
||||
upsample_nearest_no_align_corners = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="nearest", align_corners=False
|
||||
)
|
||||
# Test single feature map, align corners
|
||||
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
|
||||
expected_nearest = mx.array(
|
||||
[[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
|
||||
).transpose((0, 2, 3, 1))
|
||||
expected_bilinear = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0, 0.333333, 0.666667, 1],
|
||||
[0.666667, 1, 1.33333, 1.66667],
|
||||
[1.33333, 1.66667, 2, 2.33333],
|
||||
[2, 2.33333, 2.66667, 3],
|
||||
]
|
||||
]
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
# Test single feature map, no align corners
|
||||
x = (
|
||||
mx.arange(1, b * h * w * c + 1)
|
||||
.reshape((b, c, h, w))
|
||||
.transpose((0, 2, 3, 1))
|
||||
)
|
||||
expected_bilinear_no_align_corners = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1.0000, 1.2500, 1.7500, 2.0000],
|
||||
[1.5000, 1.7500, 2.2500, 2.5000],
|
||||
[2.5000, 2.7500, 3.2500, 3.5000],
|
||||
[3.0000, 3.2500, 3.7500, 4.0000],
|
||||
]
|
||||
]
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
expected_nearest_no_align_corners = mx.array(
|
||||
[[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]
|
||||
).transpose((0, 2, 3, 1))
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
upsample_bilinear_no_align_corners(x),
|
||||
expected_bilinear_no_align_corners,
|
||||
)
|
||||
)
|
||||
|
||||
# Test a more complex batch
|
||||
b, h, w, c = 2, 3, 3, 2
|
||||
scale_factor = 2
|
||||
x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
|
||||
|
||||
upsample_nearest = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="nearest", align_corners=True
|
||||
)
|
||||
upsample_bilinear = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="linear", align_corners=True
|
||||
)
|
||||
|
||||
expected_nearest = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
|
||||
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
|
||||
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
|
||||
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
|
||||
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
|
||||
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
|
||||
],
|
||||
[
|
||||
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
|
||||
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
|
||||
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
|
||||
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
|
||||
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
|
||||
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
|
||||
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
|
||||
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
|
||||
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
|
||||
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
|
||||
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
|
||||
],
|
||||
[
|
||||
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
|
||||
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
|
||||
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
|
||||
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
|
||||
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
|
||||
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
|
||||
],
|
||||
],
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
expected_bilinear = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
|
||||
[1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
|
||||
[2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
|
||||
[3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
|
||||
[4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
|
||||
[6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
|
||||
],
|
||||
[
|
||||
[9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
|
||||
[10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
|
||||
[11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
|
||||
[12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
|
||||
[13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
|
||||
[15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
|
||||
[19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
|
||||
[20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
|
||||
[21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
|
||||
[22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
|
||||
[24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
|
||||
],
|
||||
[
|
||||
[27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
|
||||
[28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
|
||||
[29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
|
||||
[30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
|
||||
[31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
|
||||
[33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
|
||||
],
|
||||
],
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
|
||||
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
|
||||
|
||||
# Test different height and width scale_factor
|
||||
b, h, w, c = 1, 2, 2, 2
|
||||
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
|
||||
upsample_nearest = nn.Upsample(
|
||||
scale_factor=(2, 3), mode="nearest", align_corners=True
|
||||
)
|
||||
upsample_bilinear = nn.Upsample(
|
||||
scale_factor=(2, 3), mode="linear", align_corners=True
|
||||
)
|
||||
|
||||
expected_nearest = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[2, 2, 2, 3, 3, 3],
|
||||
[2, 2, 2, 3, 3, 3],
|
||||
],
|
||||
[
|
||||
[4, 4, 4, 5, 5, 5],
|
||||
[4, 4, 4, 5, 5, 5],
|
||||
[6, 6, 6, 7, 7, 7],
|
||||
[6, 6, 6, 7, 7, 7],
|
||||
],
|
||||
]
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
expected_bilinear = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0, 0.2, 0.4, 0.6, 0.8, 1],
|
||||
[0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
|
||||
[1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
|
||||
[2, 2.2, 2.4, 2.6, 2.8, 3],
|
||||
],
|
||||
[
|
||||
[4, 4.2, 4.4, 4.6, 4.8, 5],
|
||||
[4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
|
||||
[5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
|
||||
[6, 6.2, 6.4, 6.6, 6.8, 7],
|
||||
],
|
||||
]
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
|
||||
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
|
||||
|
||||
# Test repr
|
||||
self.assertEqual(
|
||||
str(nn.Upsample(scale_factor=2)),
|
||||
"Upsample(scale_factor=2.0, mode='nearest', align_corners=False)",
|
||||
)
|
||||
self.assertEqual(
|
||||
str(nn.Upsample(scale_factor=(2, 3))),
|
||||
"Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)",
|
||||
)
|
||||
|
||||
def test_pooling(self):
|
||||
# Test 1d pooling
|
||||
x = mx.array(
|
||||
|
||||
@@ -1047,6 +1047,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a = mx.arange(0, float("inf"), float("inf"))
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.arange(float("inf"), 1, float("inf"))
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.arange(float("inf"), 1, 5)
|
||||
with self.assertRaises(ValueError):
|
||||
INT_MAX = 2147483647
|
||||
a = mx.arange(0, INT_MAX + 1, 1)
|
||||
|
||||
a = mx.arange(5)
|
||||
expected = [0, 1, 2, 3, 4]
|
||||
@@ -1132,6 +1137,27 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
self.assertEqual(a.dtype, mx.int32)
|
||||
|
||||
a = mx.arange(0, 10, 100)
|
||||
expected = [0]
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
self.assertEqual(a.dtype, mx.int32)
|
||||
|
||||
a = mx.arange(10, 0, 1)
|
||||
expected = []
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
a = mx.arange(10, 0, float("inf"))
|
||||
expected = []
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
a = mx.arange(0, 10, float("inf"))
|
||||
expected = [0]
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
a = mx.arange(0, -10, float("-inf"))
|
||||
expected = [0]
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
def test_unary_ops(self):
|
||||
def test_ops(npop, mlxop, x, y, atol):
|
||||
r_np = npop(x)
|
||||
@@ -1563,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
shape = (3, 4, 5)
|
||||
for dtype in ("int32", "float32"):
|
||||
for axis in (None, 0, 1, 2):
|
||||
for kth in (-2, 2):
|
||||
for kth in (-2, 0, 2):
|
||||
with self.subTest(dtype=dtype, axis=axis, kth=kth):
|
||||
np.random.seed(0)
|
||||
np_dtype = getattr(np, dtype)
|
||||
@@ -1579,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.array_equal(c_np, c_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
|
||||
top_k_mx = mx.topk(a_mx, kth, axis=axis)
|
||||
self.assertTrue(np.all(c_np <= top_k_mx))
|
||||
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
|
||||
|
||||
if kth >= 0:
|
||||
d_np = np.take(b_mx, np.arange(kth), axis=axis)
|
||||
self.assertTrue(np.all(d_np <= c_mx))
|
||||
top_k_mx = mx.topk(a_mx, kth, axis=axis)
|
||||
top_k_np = np.take(
|
||||
np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis
|
||||
)
|
||||
self.assertTrue(np.all(top_k_np <= top_k_mx))
|
||||
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
|
||||
N = a_mx.shape[axis] if axis is not None else a_mx.size
|
||||
M = top_k_mx.shape[axis or 0]
|
||||
self.assertEqual(M, (kth + N) % N)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.getenv("LOW_MEMORY", None) is not None,
|
||||
@@ -1906,12 +1935,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_1d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_1d(mx.array(array))
|
||||
np_res = np.atleast_1d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_atleast_2d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
@@ -1936,12 +1969,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_2d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_2d(mx.array(array))
|
||||
np_res = np.atleast_2d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_atleast_3d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
@@ -1966,12 +2003,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_3d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_3d(mx.array(array))
|
||||
np_res = np.atleast_3d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
|
||||
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||
|
||||
def test_schedule_joiner(self):
|
||||
boundaries = [2, 3, 4]
|
||||
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
|
||||
with self.assertRaises(ValueError):
|
||||
opt.schedulers.join_schedules(schedules, boundaries)
|
||||
boundaries = [2, 4]
|
||||
schedule = opt.schedulers.join_schedules(schedules, boundaries)
|
||||
self.assertEqual(schedule(0).item(), 3)
|
||||
self.assertEqual(schedule(1).item(), 3)
|
||||
self.assertEqual(schedule(2).item(), 4)
|
||||
self.assertEqual(schedule(3).item(), 4)
|
||||
self.assertEqual(schedule(5).item(), 5)
|
||||
self.assertEqual(schedule(7).item(), 5)
|
||||
|
||||
def test_linear_warmup_with_cosine_decay(self):
|
||||
warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
|
||||
cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
|
||||
cos_with_warmup = opt.schedulers.join_schedules(
|
||||
[warmup_schedule, cosine_schedule], [101]
|
||||
)
|
||||
self.assertEqual(cos_with_warmup(0), 0.0)
|
||||
self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
|
||||
optimizer = opt.Adam(learning_rate=cos_with_warmup)
|
||||
for _ in range(100):
|
||||
optimizer.update({}, {})
|
||||
self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
|
||||
for _ in range(100):
|
||||
optimizer.update({}, {})
|
||||
expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
|
||||
self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
|
||||
|
||||
def test_compile_with_schedule(self):
|
||||
lr_schedule = opt.exponential_decay(1e-1, 0.9)
|
||||
optimizer = opt.SGD(learning_rate=lr_schedule)
|
||||
|
||||
Reference in New Issue
Block a user