mlx/python/tests/test_array.py
2024-12-06 10:48:58 -08:00

1956 lines
63 KiB
Python

# Copyright © 2023-2024 Apple Inc.
import gc
import operator
import pickle
import resource
import sys
import unittest
import weakref
from copy import copy, deepcopy
from itertools import permutations
import mlx.core as mx
import mlx_tests
import numpy as np
try:
import tensorflow as tf
has_tf = True
except ImportError as e:
has_tf = False
class TestVersion(mlx_tests.MLXTestCase):
    """Sanity checks on the ``mx.__version__`` string."""

    def test_version(self):
        version = mx.__version__
        parts = version.split(".")
        # Expect at least major.minor.patch components.
        self.assertGreaterEqual(len(parts), 3)
        # The first three components must parse as integers; re-joining their
        # canonical integer form has to reproduce the version's prefix.
        canonical = ".".join([str(int(part)) for part in parts[:3]])
        self.assertEqual(canonical, version[: len(canonical)])
class TestDtypes(mlx_tests.MLXTestCase):
    """Dtype metadata and numpy-scalar round trips."""

    def test_dtypes(self):
        # (attribute on ``mx``, size in bytes, expected ``str()`` form)
        table = [
            ("bool_", 1, "mlx.core.bool"),
            ("uint8", 1, "mlx.core.uint8"),
            ("uint16", 2, "mlx.core.uint16"),
            ("uint32", 4, "mlx.core.uint32"),
            ("uint64", 8, "mlx.core.uint64"),
            ("int8", 1, "mlx.core.int8"),
            ("int16", 2, "mlx.core.int16"),
            ("int32", 4, "mlx.core.int32"),
            ("int64", 8, "mlx.core.int64"),
            ("float16", 2, "mlx.core.float16"),
            ("float32", 4, "mlx.core.float32"),
            ("bfloat16", 2, "mlx.core.bfloat16"),
            ("complex64", 8, "mlx.core.complex64"),
        ]
        for attr, nbytes, _ in table:
            self.assertEqual(getattr(mx, attr).size, nbytes)
        for attr, _, text in table:
            self.assertEqual(str(getattr(mx, attr)), text)

    def test_scalar_conversion(self):
        names = (
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "int32",
            "int64",
            "float16",
            "float32",
            "complex64",
        )
        for name in names:
            with self.subTest(dtype=name):
                arr = np.array(2, dtype=getattr(np, name))
                # Reducing the 0-d array gives the numpy scalar form under test.
                scalar = np.min(arr)
                self.assertEqual(arr.dtype, scalar.dtype)
                self.assertTupleEqual(arr.shape, scalar.shape)

                converted = mx.array(scalar)
                self.assertEqual(np.array(converted), arr)
                self.assertEqual(np.array(converted), scalar)
                self.assertEqual(converted.dtype, getattr(mx, name))
                self.assertListEqual(list(converted.shape), list(arr.shape))
                self.assertListEqual(list(converted.shape), list(scalar.shape))
class TestEquality(mlx_tests.MLXTestCase):
    """``==`` between mlx arrays and arrays, scalars and other objects."""

    def test_array_eq_array(self):
        base = mx.array([1, 2, 3])
        same = mx.array([1, 2, 3])
        differs = mx.array([1, 2, 4])
        self.assertTrue(mx.all(base == same))
        self.assertFalse(mx.all(base == differs))

    def test_array_eq_scalar(self):
        arr = mx.array([1, 2, 3])
        self.assertTrue(mx.any(arr == 1))
        self.assertFalse(mx.all(arr == 4))
        self.assertFalse(mx.all(arr == 2.5))
        self.assertTrue(mx.any(arr == mx.array([1, 2.5, 3.25])))

    def test_list_equals_array(self):
        arr = mx.array([1, 2, 3])
        # Comparing an mlx array with a non-array object (a list here) is
        # expected to give False, even when the contents match.
        for other in ([1, 2, 3], [1, 2, 4]):
            self.assertFalse(arr == other)

    def test_tuple_equals_array(self):
        arr = mx.array([1, 2, 3])
        # Same contract as for lists: tuples never compare equal.
        for other in ((1, 2, 3), (1, 2, 4)):
            self.assertFalse(arr == other)
class TestInequality(mlx_tests.MLXTestCase):
    """``!=``, ordering comparisons and unsupported operand types."""

    def test_array_ne_array(self):
        base = mx.array([1, 2, 3])
        self.assertFalse(mx.any(base != mx.array([1, 2, 3])))
        self.assertTrue(mx.any(base != mx.array([1, 2, 4])))

    def test_array_ne_scalar(self):
        arr = mx.array([1, 2, 3])
        self.assertFalse(mx.all(arr != 1))
        self.assertTrue(mx.any(arr != 4))
        self.assertTrue(mx.any(arr != 1.5))
        self.assertTrue(mx.any(arr != 2.5))
        self.assertFalse(mx.all(arr != mx.array([1, 2.5, 3.25])))

    def test_list_not_equals_array(self):
        arr = mx.array([1, 2, 3])
        # ``!=`` against a non-array object (a list here) is expected to be
        # True, even when the contents match.
        for other in ([1, 2, 3], [1, 2, 4]):
            self.assertTrue(arr != other)

    def test_dlx_device_type(self):
        arr = mx.array([1, 2, 3])
        device_type, device_id = arr.__dlpack_device__()
        # DLPack device codes accepted here: 1 (CPU) or 8 (Metal).
        self.assertIn(device_type, [1, 8])
        self.assertEqual(device_id, 0)
        # The reported device must agree with Metal availability.
        if device_type == 8:
            self.assertTrue(mx.metal.is_available())
        else:
            self.assertFalse(mx.metal.is_available())

    def test_tuple_not_equals_array(self):
        arr = mx.array([1, 2, 3])
        for other in ((1, 2, 3), (1, 2, 4)):
            self.assertTrue(arr != other)

    def test_obj_inequality_array(self):
        arr = mx.array([1, 2, 3])
        # Ordering comparisons against a string, list or tuple must raise.
        operands = ["hello", [1, 2, 3], (1, 2, 3)]
        comparisons = [operator.lt, operator.gt, operator.le, operator.ge]
        for other in operands:
            for compare in comparisons:
                with self.assertRaises(ValueError):
                    compare(arr, other)

    def test_invalid_op_on_array(self):
        arr = mx.array([1, 2.5, 3.25])
        text = "hello"
        as_list = [1, 2.1, 3.25]
        as_tuple = (1, 2.5, 3.25)
        # Arithmetic / bitwise operators (in-place forms included) with these
        # operand types must raise ValueError.
        cases = [
            (operator.mul, text),
            (operator.imul, text),
            (operator.itruediv, as_list),
            (operator.floordiv, as_list),
            (operator.mod, as_list),
            (operator.pow, as_tuple),
            (operator.and_, as_tuple),
            (operator.or_, text),
        ]
        for op, operand in cases:
            with self.assertRaises(ValueError):
                op(arr, operand)
class TestArray(mlx_tests.MLXTestCase):
def test_array_basics(self):
    """Scalar and 1-D construction: size, shape, dtype and item() types."""
    scalar = mx.array(1)
    self.assertEqual(scalar.size, 1)
    self.assertEqual(scalar.ndim, 0)
    self.assertEqual(scalar.itemsize, 4)
    self.assertEqual(scalar.nbytes, 4)
    self.assertEqual(scalar.shape, ())
    self.assertEqual(scalar.dtype, mx.int32)
    self.assertEqual(scalar.item(), 1)
    self.assertTrue(isinstance(scalar.item(), int))
    # A 0-d array has no length.
    with self.assertRaises(TypeError):
        len(scalar)

    # Explicit integer dtypes still hand back Python ints.
    for int_dtype in (mx.uint32, mx.int64):
        scalar = mx.array(1, int_dtype)
        self.assertEqual(scalar.item(), 1)
        self.assertTrue(isinstance(scalar.item(), int))

    self.assertEqual(mx.array(1, mx.bfloat16).item(), 1.0)

    # Python float / bool scalars map to 0-d arrays of the matching dtype
    # and item() returns the matching Python type.
    for value, dtype, py_type in ((1.0, mx.float32, float), (False, mx.bool_, bool)):
        scalar = mx.array(value)
        self.assertEqual(scalar.size, 1)
        self.assertEqual(scalar.ndim, 0)
        self.assertEqual(scalar.shape, ())
        self.assertEqual(scalar.dtype, dtype)
        self.assertEqual(scalar.item(), value)
        self.assertTrue(isinstance(scalar.item(), py_type))

    cplx = mx.array(complex(1, 1))
    self.assertEqual(cplx.ndim, 0)
    self.assertEqual(cplx.shape, ())
    self.assertEqual(cplx.dtype, mx.complex64)
    self.assertEqual(cplx.item(), complex(1, 1))
    self.assertTrue(isinstance(cplx.item(), complex))

    vec = mx.array([True, False, True])
    self.assertEqual(vec.dtype, mx.bool_)
    self.assertEqual(vec.ndim, 1)
    self.assertEqual(vec.shape, (3,))
    self.assertEqual(len(vec), 3)
    self.assertEqual(mx.array([True, False, True], mx.float32).dtype, mx.float32)

    vec = mx.array([0, 1, 2])
    self.assertEqual(vec.dtype, mx.int32)
    self.assertEqual(vec.ndim, 1)
    self.assertEqual(vec.shape, (3,))
    self.assertEqual(mx.array([0, 1, 2], mx.float32).dtype, mx.float32)

    vec = mx.array([0.0, 1.0, 2.0])
    self.assertEqual(vec.dtype, mx.float32)
    self.assertEqual(vec.ndim, 1)
    self.assertEqual(vec.shape, (3,))

    vec = mx.array([1j, 1 + 0j])
    self.assertEqual(vec.dtype, mx.complex64)
    self.assertEqual(vec.ndim, 1)
    self.assertEqual(vec.shape, (2,))

    # Tuples are accepted as input as well as lists.
    vec = mx.array((1, 2, 3), mx.int32)
    self.assertEqual(vec.dtype, mx.int32)
    self.assertEqual(vec.tolist(), [1, 2, 3])
def test_bool_conversion(self):
    """bool() of a 0-d array follows the truthiness of its value."""
    cases = ((True, True), (False, False), (1.0, True), (0.0, False))
    for value, truthy in cases:
        if truthy:
            self.assertTrue(mx.array(value))
        else:
            self.assertFalse(mx.array(value))
def test_construction_from_lists(self):
    """Nested Python lists: empty shapes, ragged input and dtype inference."""
    # Empty nested lists become float32 arrays shaped by the nesting.
    for nested, shape in (
        ([], (0,)),
        ([[], [], []], (3, 0)),
        ([[[], []], [[], []], [[], []]], (3, 2, 0)),
    ):
        arr = mx.array(nested)
        self.assertEqual(arr.size, 0)
        self.assertEqual(arr.shape, shape)
        self.assertEqual(arr.dtype, mx.float32)

    # Ragged nesting and non-numeric entries are rejected.
    bad_inputs = [
        [[[], []], [[]], [[], []]],
        [[[], []], [[1.0, 2.0], []], [[], []]],
        [[0, 1], [[0, 1], 1]],
        [[0, 1], ["hello", 1]],
    ]
    for bad in bad_inputs:
        with self.assertRaises(ValueError):
            mx.array(bad)

    # Mixed bool/int/float entries promote to the widest kind present.
    self.assertEqual(mx.array([True, False, 3]).dtype, mx.int32)
    self.assertEqual(mx.array([True, False, 3, 4.0]).dtype, mx.float32)
    self.assertEqual(mx.array([[True, False], [1, 3], [2, 4.0]]).dtype, mx.float32)

    # An explicit dtype converts the values.
    as_bool = mx.array([[1.0, 2.0], [0.0, 3.9]], mx.bool_)
    self.assertEqual(as_bool.dtype, mx.bool_)
    self.assertTrue(
        mx.array_equal(as_bool, mx.array([[True, True], [False, True]]))
    )
    as_int = mx.array([[1.0, 2.0], [0.0, 3.9]], mx.int32)
    self.assertTrue(mx.array_equal(as_int, mx.array([[1, 2], [0, 3]])))
    as_cplx = mx.array([1 + 0j, 2j, True, 0], mx.complex64)
    self.assertEqual(as_cplx.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])

    # The top of the uint32 range survives construction.
    for np_dtype, mx_dtype in ((np.uint32, mx.uint32), (np.float32, mx.float32)):
        reference = np.array([0, 4294967295], dtype=np_dtype)
        built = mx.array([0, 4294967295], dtype=mx_dtype)
        self.assertTrue(np.array_equal(built, reference))
def test_construction_from_lists_of_mlx_arrays(self):
    """Lists containing mlx arrays: promotion, mixing, errors and autograd."""
    all_dtypes = [
        mx.bool_,
        mx.uint8,
        mx.uint16,
        mx.uint32,
        mx.uint64,
        mx.int8,
        mx.int16,
        mx.int32,
        mx.int64,
        mx.float16,
        mx.float32,
        mx.bfloat16,
        mx.complex64,
    ]
    for first_t, second_t in permutations(all_dtypes, 2):
        # Type promotion and values must agree with an explicit stack.
        first = mx.array([1.0], first_t)
        second = mx.array([2.0], second_t)
        self.assertEqualArray(
            mx.array([first, second]), mx.stack([first, second], axis=0)
        )

        # mlx arrays may be mixed with plain Python lists at the same level.
        first = mx.array([True], first_t)
        second = mx.array([False], second_t)
        mixed = mx.array([[first, [2.0]], [[3.0], second]])
        reference = mx.array(
            [[[first.item()], [2.0]], [[3.0], [second.item()]]], mixed.dtype
        )
        self.assertEqualArray(mixed, reference)

    # An element produced by astype() that does not hold raw data yet.
    lazy = mx.array([1.0]).astype(mx.bfloat16)
    for other_t in all_dtypes:
        other = mx.array([2.0], other_t)
        self.assertEqualArray(
            mx.array([lazy, other]), mx.stack([lazy, other], axis=0)
        )

    message = "Initialization encountered non-uniform length."
    # Array first, then scalar: shape check from `stack()`.
    with self.assertRaises(ValueError) as ctx:
        mx.array([lazy, 1.0])
    self.assertEqual(str(ctx.exception), message)
    # Scalar first, then array: shape check from `validate_shape`.
    with self.assertRaises(ValueError) as ctx:
        mx.array([1.0, lazy])
    self.assertEqual(str(ctx.exception), message)

    # Arrays embedded in a list stay part of the autograd graph.
    def loss(inp):
        stacked = mx.array([inp, mx.array([2.0])])
        return (2 * stacked).sum()

    grad_fn = mx.grad(loss)
    self.assertEqual(grad_fn(mx.array([1.0])).item(), 2.0)
def test_init_from_array(self):
    """mx.array(existing[, dtype]) copies or casts an existing array."""
    src = mx.array(3.0)
    self.assertTrue(mx.array_equal(src, mx.array(src)))
    for dtype, value in ((mx.int32, 3), (mx.bool_, True), (mx.complex64, 3.0 + 0j)):
        cast = mx.array(src, dtype)
        self.assertEqual(cast.dtype, dtype)
        self.assertEqual(cast.item(), value)
def test_array_repr(self):
    """str() of arrays: scalars, nested layout, truncation and dtypes.

    In the expected strings below, continuation rows are indented so they
    line up with the first element: 7 spaces (under ``array([``) at the
    outermost level and one more space for each deeper level.
    """
    x = mx.array(True)
    self.assertEqual(str(x), "array(True, dtype=bool)")
    x = mx.array(1)
    self.assertEqual(str(x), "array(1, dtype=int32)")
    x = mx.array(1.0)
    self.assertEqual(str(x), "array(1, dtype=float32)")

    x = mx.array([1, 0, 1])
    self.assertEqual(str(x), "array([1, 0, 1], dtype=int32)")

    # 6 elements print in full, 7 are elided with "...".
    x = mx.array([1] * 6)
    expected = "array([1, 1, 1, 1, 1, 1], dtype=int32)"
    self.assertEqual(str(x), expected)
    x = mx.array([1] * 7)
    expected = "array([1, 1, 1, ..., 1, 1, 1], dtype=int32)"
    self.assertEqual(str(x), expected)

    x = mx.array([[1, 2], [1, 2], [1, 2]])
    expected = (
        "array([[1, 2],\n"
        "       [1, 2],\n"
        "       [1, 2]], dtype=int32)"
    )
    self.assertEqual(str(x), expected)

    x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
    expected = (
        "array([[[1, 2],\n"
        "        [1, 2]],\n"
        "       [[1, 2],\n"
        "        [1, 2]]], dtype=int32)"
    )
    self.assertEqual(str(x), expected)

    x = mx.array([[1, 2]] * 6)
    expected = (
        "array([[1, 2],\n"
        "       [1, 2],\n"
        "       [1, 2],\n"
        "       [1, 2],\n"
        "       [1, 2],\n"
        "       [1, 2]], dtype=int32)"
    )
    self.assertEqual(str(x), expected)

    # Rows are elided the same way as elements.
    x = mx.array([[1, 2]] * 7)
    expected = (
        "array([[1, 2],\n"
        "       [1, 2],\n"
        "       [1, 2],\n"
        "       ...,\n"
        "       [1, 2],\n"
        "       [1, 2],\n"
        "       [1, 2]], dtype=int32)"
    )
    self.assertEqual(str(x), expected)

    x = mx.array([1], dtype=mx.int8)
    expected = "array([1], dtype=int8)"
    self.assertEqual(str(x), expected)
    x = mx.array([1], dtype=mx.int16)
    expected = "array([1], dtype=int16)"
    self.assertEqual(str(x), expected)
    x = mx.array([1], dtype=mx.uint8)
    expected = "array([1], dtype=uint8)"
    self.assertEqual(str(x), expected)

    # float16 prints at float16 precision: 1.2 is stored as about 1.2002.
    x = mx.array([1.2], dtype=mx.float16)
    expected = "array([1.2002], dtype=float16)"
    self.assertEqual(str(x), expected)

    # Complex values print with an explicit sign on the imaginary part.
    x = mx.array([1 + 1j], dtype=mx.complex64)
    expected = "array([1+1j], dtype=complex64)"
    self.assertEqual(str(x), expected)
    x = mx.array([1 - 1j], dtype=mx.complex64)
    expected = "array([1-1j], dtype=complex64)"
    self.assertEqual(str(x), expected)
def test_array_to_list(self):
    """tolist() reproduces the nested Python input for many dtypes/shapes."""
    for dtype in (mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32):
        self.assertEqual(mx.array(1, dtype).tolist(), 1)

    # (input values, explicit dtype or None for the inferred one)
    cases = [
        ([1, 2, 3, 4], None),
        ([[1, 2], [3, 4]], None),
        ([[1, 0], [0, 1]], mx.bool_),
        ([[1.5, 2.5], [3.5, 4.5]], None),
        ([[[0.5, 1.5], [2.5, 3.5]], [[4.5, 5.5], [6.5, 7.5]]], None),
        # Empty arrays
        ([], None),
        ([[], []], None),
        # Complex arrays
        ([0.5 + 0j, 1.5 + 1j, 2.5 + 0j, 3.5 + 1j], None),
        # Half types
        ([1.0, 2.0, 3.0, 4.0, 5.0], mx.float16),
        ([1.0, 2.0, 3.0, 4.0, 5.0], mx.bfloat16),
    ]
    for vals, dtype in cases:
        arr = mx.array(vals) if dtype is None else mx.array(vals, dtype=dtype)
        self.assertEqual(arr.tolist(), vals)
def test_array_np_conversion(self):
    """Round trips between numpy ndarrays and mlx arrays."""
    # Empty numpy arrays keep their shape and arrive as float32.
    for nested, shape in (
        ([], (0,)),
        ([[], [], []], (3, 0)),
        ([[[], []], [[], []], [[], []]], (3, 2, 0)),
    ):
        arr = mx.array(np.array(nested))
        self.assertEqual(arr.size, 0)
        self.assertEqual(arr.shape, shape)
        self.assertEqual(arr.dtype, mx.float32)

    # numpy float64 content becomes float32 with the same values.
    src = 2.0 * np.ones((3, 5, 4))
    arr = mx.array(src)
    self.assertEqual(arr.dtype, mx.float32)
    self.assertEqual(arr.ndim, 3)
    self.assertEqual(arr.shape, (3, 5, 4))
    self.assertTrue(np.allclose(src, np.asarray(arr)))

    # A 0-d numpy array becomes a 0-d mlx array.
    arr = mx.array(np.array(3, dtype=np.int32))
    self.assertEqual(arr.dtype, mx.int32)
    self.assertEqual(arr.ndim, 0)
    self.assertEqual(arr.shape, ())
    self.assertEqual(arr.item(), 3)

    # mlx -> numpy for booleans.
    flags = [True, False, True]
    back = np.asarray(mx.array(flags))
    self.assertEqual(back.dtype, np.bool_)
    self.assertEqual(back.ndim, 1)
    self.assertEqual(back.shape, (3,))
    for position, flag in enumerate(flags):
        self.assertEqual(back[position], flag)

    # complex64 in both directions.
    cvals = [0j, 1, 1 + 1j]
    from_np = mx.array(np.array(cvals))
    self.assertEqual(from_np.dtype, mx.complex64)
    self.assertEqual(from_np.shape, (3,))
    self.assertEqual(from_np.tolist(), cvals)
    to_np = np.asarray(mx.array([0j, 1, 1 + 1j]))
    self.assertEqual(to_np.dtype, np.complex64)
    self.assertEqual(to_np.shape, (3,))
    self.assertEqual(to_np.tolist(), cvals)
def test_array_np_dtype_conversion(self):
    """Each numpy dtype maps to the matching mlx dtype, and back."""
    pairs = [
        (mx.bool_, np.bool_),
        (mx.uint8, np.uint8),
        (mx.uint16, np.uint16),
        (mx.uint32, np.uint32),
        (mx.uint64, np.uint64),
        (mx.int8, np.int8),
        (mx.int16, np.int16),
        (mx.int32, np.int32),
        (mx.int64, np.int64),
        (mx.float16, np.float16),
        (mx.float32, np.float32),
        (mx.complex64, np.complex64),
    ]
    for mx_dtype, np_dtype in pairs:
        # numpy -> mlx keeps the dtype and the values.
        src = np.random.uniform(low=0, high=100, size=(32,)).astype(np_dtype)
        converted = mx.array(src)
        self.assertEqual(converted.dtype, mx_dtype)
        self.assertTrue(np.allclose(converted, src))

        # mlx -> numpy keeps the dtype.
        rand = mx.random.uniform(low=0, high=10, shape=(32,)).astype(mx_dtype)
        self.assertEqual(np.array(rand).dtype, np_dtype)
def test_array_from_noncontiguous_np(self):
    """mx.array accepts non-contiguous (transposed) numpy input of each dtype."""
    for t in [np.int8, np.int32, np.float16, np.float32, np.complex64]:
        with self.subTest(dtype=t):
            # Cast to the dtype under test (the loop variable); arange gives
            # distinct, asymmetric values, so a layout mix-up is visible,
            # and 0..99 is exactly representable in every dtype listed.
            np_arr = np.arange(100).reshape(10, 10).astype(t)
            np_arr = np_arr.T  # transposed view: no longer C-contiguous
            mx_arr = mx.array(np_arr)
            self.assertTrue(mx.array_equal(np_arr, mx_arr))
def test_array_np_shape_dim_check(self):
    """A numpy dimension of length 2**31 is rejected with a clear message."""
    # np.empty leaves the buffer uninitialized, so nothing is written to it.
    oversized = np.empty(2**31, dtype=np.bool_)
    with self.assertRaises(ValueError) as ctx:
        mx.array(oversized)
    self.assertEqual(
        str(ctx.exception), "Shape dimension falls outside supported `int` range."
    )
def test_dtype_promotion(self):
    """Binary-op result dtypes agree with numpy for the listed dtype pairs."""
    pairs = [
        (mx.bool_, np.bool_),
        (mx.uint8, np.uint8),
        (mx.uint16, np.uint16),
        (mx.uint32, np.uint32),
        (mx.uint64, np.uint64),
        (mx.int8, np.int8),
        (mx.int16, np.int16),
        (mx.int32, np.int32),
        (mx.int64, np.int64),
        (mx.float32, np.float32),
    ]
    for (mx_a, np_a), (mx_b, np_b) in permutations(pairs, 2):
        with self.subTest(dtype1=np_a, dtype2=np_b):
            np_sum = np.ones((3,), dtype=np_a) + np.ones((3,), dtype=np_b)
            mx_sum = mx.ones((3,), dtype=mx_a) + mx.ones((3,), dtype=mx_b)
            self.assertEqual(mx_sum.dtype, mx.array(np_sum).dtype)

    # float16 widens with float32 but is kept when combined with int32.
    half = mx.ones((3,), dtype=mx.float16)
    self.assertEqual((half + mx.ones((3,), dtype=mx.float32)).dtype, mx.float32)
    self.assertEqual((half + mx.ones((3,), dtype=mx.int32)).dtype, mx.float16)
def test_dtype_python_scalar_promotion(self):
    """Result dtype when an array is multiplied by a Python bool/int/float."""
    # (input dtype, result with * False, result with * 0, result with * 1.0)
    table = [
        (mx.bool_, mx.bool_, mx.int32, mx.float32),
        (mx.int8, mx.int8, mx.int8, mx.float32),
        (mx.int16, mx.int16, mx.int16, mx.float32),
        (mx.int32, mx.int32, mx.int32, mx.float32),
        (mx.int64, mx.int64, mx.int64, mx.float32),
        (mx.uint8, mx.uint8, mx.uint8, mx.float32),
        (mx.uint16, mx.uint16, mx.uint16, mx.float32),
        (mx.uint32, mx.uint32, mx.uint32, mx.float32),
        (mx.uint64, mx.uint64, mx.uint64, mx.float32),
        (mx.float32, mx.float32, mx.float32, mx.float32),
        (mx.float16, mx.float16, mx.float16, mx.float16),
    ]
    for dtype_in, with_bool, with_int, with_float in table:
        arr = mx.array(0, dtype_in)
        for scalar, expected in ((False, with_bool), (0, with_int), (1.0, with_float)):
            self.assertEqual((arr * scalar).dtype, expected)
def test_array_comparison(self):
    """Elementwise ordering comparisons with arrays and Python scalars."""
    a = mx.array([0.0, 1.0, 5.0])
    b = mx.array([-1.0, 2.0, 5.0])
    expectations = [
        (a < b, [False, True, False]),
        (a <= b, [False, True, True]),
        (a > b, [True, False, False]),
        (a >= b, [True, False, True]),
        (a < 5, [True, True, False]),
        (5 < a, [False, False, False]),
        (5 <= a, [False, False, True]),
        (a > 1, [False, False, True]),
        (a >= 1, [False, True, True]),
    ]
    for result, expected in expectations:
        self.assertEqual(result.tolist(), expected)
def test_array_neg(self):
    """Unary minus negates each element."""
    negated = -mx.array([-1.0, 4.0, 0.0])
    self.assertEqual(negated.tolist(), [1.0, -4.0, 0.0])
def test_array_type_cast(self):
    """astype(): float -> int32 drops the fraction; int -> float32 dtype."""
    as_int = mx.array([0.1, 2.3, -1.3]).astype(mx.int32)
    self.assertEqual(as_int.tolist(), [0, 2, -1])
    self.assertEqual(as_int.dtype, mx.int32)
    as_float = mx.array([0, 2, -1]).astype(mx.float32)
    self.assertEqual(as_float.dtype, mx.float32)
def test_array_iteration(self):
    """Iterating an array yields its sub-arrays along the first axis."""
    for expected, elem in enumerate(mx.array([0, 1, 2])):
        self.assertEqual(elem.item(), expected)

    rows = mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    first, second, third = rows
    self.assertEqual(first.tolist(), [1.0, 2.0])
    self.assertEqual(second.tolist(), [3.0, 4.0])
    self.assertEqual(third.tolist(), [5.0, 6.0])
def test_array_pickle(self):
    """pickle round-trips arrays of the listed dtypes; bfloat16 is refused."""
    data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    picklable = [
        mx.int8,
        mx.int16,
        mx.int32,
        mx.int64,
        mx.uint8,
        mx.uint16,
        mx.uint32,
        mx.uint64,
        mx.float16,
        mx.float32,
        mx.complex64,
    ]
    for dtype in picklable:
        original = mx.array(data, dtype=dtype)
        restored = pickle.loads(pickle.dumps(original))
        self.assertEqualArray(restored, original)

    # bfloat16 is not supported and must raise TypeError.
    unsupported = mx.array(data, dtype=mx.bfloat16)
    with self.assertRaises(TypeError):
        pickle.dumps(unsupported)
def test_array_copy(self):
    """copy() and deepcopy() return arrays independent of the source."""
    data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    dtypes = [
        mx.int8,
        mx.int16,
        mx.int32,
        mx.int64,
        mx.uint8,
        mx.uint16,
        mx.uint32,
        mx.uint64,
        mx.float16,
        mx.float32,
        mx.bfloat16,
        mx.complex64,
    ]
    for copier in (copy, deepcopy):
        for dtype in dtypes:
            src = mx.array(data, dtype=dtype)
            dup = copier(src)
            self.assertEqualArray(dup, src)
            # Updating the copy in place must not affect the source.
            dup -= 1
            self.assertEqualArray(dup, src - 1)
def test_indexing(self):
    """__getitem__ with ellipsis, ints, slices, None and integer arrays.

    Results are compared with numpy indexing of the same data.
    """
    # Only ellipsis is a no-op
    a_mlx = mx.array([1])[...]
    self.assertEqual(a_mlx.shape, (1,))
    self.assertEqual(a_mlx.item(), 1)

    # Basic content check, slice indexing
    a_npy = np.arange(64, dtype=np.float32)
    a_mlx = mx.array(a_npy)
    a_sliced_mlx = a_mlx[2:50:4]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[2:50:4]))

    # Basic content check, mlx array indexing
    a_npy = np.arange(64, dtype=np.int32)
    a_npy = a_npy.reshape((8, 8))
    a_mlx = mx.array(a_npy)
    idx_npy = np.array([0, 1, 2, 7, 5], dtype=np.uint32)
    idx_mlx = mx.array(idx_npy)
    a_sliced_mlx = a_mlx[idx_mlx]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy]))

    # Basic content check, int indexing
    a_sliced_mlx = a_mlx[5]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[5]))
    self.assertEqual(len(a_sliced_npy.shape), len(a_npy[5].shape))
    self.assertEqual(len(a_sliced_npy.shape), 1)
    self.assertEqual(a_sliced_npy.shape[0], a_npy[5].shape[0])

    # Basic content check, negative indexing
    a_sliced_mlx = a_mlx[-1]
    self.assertTrue(np.array_equal(a_sliced_mlx, a_npy[-1]))

    # Basic content check, empty index
    a_sliced_mlx = a_mlx[()]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[()]))

    # Basic content check, new axis
    a_sliced_mlx = a_mlx[None]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[None]))
    a_sliced_mlx = a_mlx[:, None]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, None]))

    # Multi dim indexing, all ints
    self.assertEqual(a_mlx[0, 0].item(), 0)
    self.assertEqual(a_mlx[0, 0].ndim, 0)

    # Multi dim indexing, all slices
    a_sliced_mlx = a_mlx[2:4, 5:]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[2:4, 5:]))
    a_sliced_mlx = a_mlx[:, 0:5]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, 0:5]))

    # Slicing, strides
    a_sliced_mlx = a_mlx[:, ::2]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, ::2]))

    # Slicing, -ve index
    a_sliced_mlx = a_mlx[-2:, :-1]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[-2:, :-1]))

    # Slicing, start > end
    a_sliced_mlx = a_mlx[8:3]
    self.assertEqual(a_sliced_mlx.size, 0)

    # Slicing, Clipping past the end
    a_sliced_mlx = a_mlx[7:10]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[7:10]))

    # Multi dim indexing, int and slices
    a_sliced_mlx = a_mlx[0, :5]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[0, :5]))
    a_sliced_mlx = a_mlx[:, -1]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, -1]))

    # Multi dim indexing, int and array
    a_sliced_mlx = a_mlx[idx_mlx, 0]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy, 0]))

    # Multi dim indexing, array and slices
    a_sliced_mlx = a_mlx[idx_mlx, :5]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy, :5]))
    a_sliced_mlx = a_mlx[:, idx_mlx]
    a_sliced_npy = np.asarray(a_sliced_mlx)
    self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, idx_npy]))

    # Multi dim indexing with multiple arrays
    def check_slices(arr_np, *idx_np):
        # Index numpy and mlx with the same key (numpy index arrays are
        # converted to mlx arrays) and require identical results.
        arr_mlx = mx.array(arr_np)
        idx_mlx = [
            mx.array(idx) if isinstance(idx, np.ndarray) else idx for idx in idx_np
        ]
        slice_mlx = arr_mlx[tuple(idx_mlx)]
        self.assertTrue(np.array_equal(arr_np[tuple(idx_np)], slice_mlx))

    a_np = np.arange(16).reshape(4, 4)
    check_slices(a_np, np.array([0, 1, 2, 3]), np.array([0, 1, 2, 3]))
    check_slices(a_np, np.array([0, 1, 2, 3]), np.array([1, 0, 3, 3]))
    check_slices(a_np, np.array([[0, 1]]), np.array([[0], [1], [3]]))

    a_np = np.arange(64).reshape(2, 4, 2, 4)
    check_slices(a_np, 0, np.array([0, 1, 2]))
    check_slices(a_np, slice(0, 1), np.array([0, 1, 2]))
    check_slices(
        a_np, slice(0, 1), np.array([0, 1, 2]), slice(None), slice(0, 4, 2)
    )
    check_slices(
        a_np, slice(0, 1), np.array([0, 1, 2]), slice(None), np.array([1, 2, 0])
    )
    check_slices(a_np, slice(0, 1), np.array([0, 1, 2]), 1, np.array([1, 2, 0]))
    check_slices(
        a_np, slice(0, 1), np.array([0, 1, 2]), np.array([1, 0, 0]), slice(0, 1)
    )
    check_slices(
        a_np,
        slice(0, 1),
        np.array([[0], [1], [2]]),
        np.array([[1, 0, 0]]),
        slice(0, 1),
    )
    check_slices(
        a_np,
        slice(0, 2),
        np.array([[0], [1], [2]]),
        slice(0, 2),
        np.array([[1, 0, 0]]),
    )
    for p in permutations([slice(None), slice(None), 0, np.array([1, 0])]):
        check_slices(a_np, *p)
    for p in permutations(
        [slice(None), slice(None), 0, np.array([1, 0]), None, None]
    ):
        check_slices(a_np, *p)
    for p in permutations([0, np.array([1, 0]), None, Ellipsis, slice(None)]):
        check_slices(a_np, *p)

    # Non-contiguous arrays in slicing
    a_mlx = mx.reshape(mx.arange(128), (16, 8))
    a_mlx = a_mlx[::2, :]
    a_np = np.array(a_mlx)
    idx_np = np.arange(8)[::2]
    idx_mlx = mx.arange(8)[::2]
    self.assertTrue(
        np.array_equal(a_np[idx_np, idx_np], np.array(a_mlx[idx_mlx, idx_mlx]))
    )

    # Slicing with negative indices and integer
    a_np = np.arange(10).reshape(5, 2)
    a_mlx = mx.array(a_np)
    self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))
def test_indexing_grad(self):
    """Gradients reach the gathered array but not the (cast) indices."""
    x = mx.array([[1, 2], [3, 4]]).astype(mx.float32)
    ind = mx.array([0, 1, 0]).astype(mx.float32)

    def gathered_sum(src, idx):
        return src[idx.astype(mx.int32)].sum()

    grad_x, grad_ind = mx.grad(gathered_sum, argnums=(0, 1))(x, ind)
    # Row 0 is gathered twice and row 1 once.
    self.assertTrue(mx.array_equal(grad_x, mx.array([[2, 2], [1, 1]])))
    self.assertTrue(mx.array_equal(grad_ind, mx.zeros(ind.shape)))
def test_setitem(self):
    """__setitem__ with ints, slices, None, ellipsis and integer arrays.

    Scalar cases assert the resulting values directly. The rest go through
    ``check_slices``, which applies the same assignment to a numpy array and
    to an mlx copy of it and requires equal results.
    """
    a = mx.array(0)
    a[None] = 1
    self.assertEqual(a.item(), 1)

    a = mx.array([1, 2, 3])
    a[0] = 2
    self.assertEqual(a.tolist(), [2, 2, 3])
    a[-1] = 2
    self.assertEqual(a.tolist(), [2, 2, 2])
    a[0] = mx.array([[[1]]])
    self.assertEqual(a.tolist(), [1, 2, 2])
    a[:] = 0
    self.assertEqual(a.tolist(), [0, 0, 0])
    a[None] = 1
    self.assertEqual(a.tolist(), [1, 1, 1])
    a[0:1] = 2
    self.assertEqual(a.tolist(), [2, 1, 1])
    a[0:2] = 3
    self.assertEqual(a.tolist(), [3, 3, 1])
    a[0:3] = 4
    self.assertEqual(a.tolist(), [4, 4, 4])
    a[0:1] = mx.array(0)
    self.assertEqual(a.tolist(), [0, 4, 4])
    a[0:1] = mx.array([1])
    self.assertEqual(a.tolist(), [1, 4, 4])
    # A 2-element update cannot be written into a 1-element region.
    with self.assertRaises(ValueError):
        a[0:1] = mx.array([2, 3])
    a[0:2] = mx.array([2, 2])
    self.assertEqual(a.tolist(), [2, 2, 4])
    a[:] = mx.array([[[[1, 1, 1]]]])
    self.assertEqual(a.tolist(), [1, 1, 1])

    # Array slices
    def check_slices(arr_np, update_np, *idx_np):
        # Apply the same assignment to numpy and mlx (numpy index arrays are
        # converted to mlx) and require equal results. NOTE: arr_np is
        # modified in place.
        arr_mlx = mx.array(arr_np)
        update_mlx = mx.array(update_np)
        idx_mlx = [
            mx.array(idx) if isinstance(idx, np.ndarray) else idx for idx in idx_np
        ]
        # A single index is used bare rather than wrapped in a 1-tuple.
        if len(idx_np) > 1:
            idx_np = tuple(idx_np)
            idx_mlx = tuple(idx_mlx)
        else:
            idx_np = idx_np[0]
            idx_mlx = idx_mlx[0]
        arr_np[idx_np] = update_np
        arr_mlx[idx_mlx] = update_mlx
        self.assertTrue(np.array_equal(arr_np, arr_mlx))

    check_slices(np.zeros((3, 3)), 1, 0)
    check_slices(np.zeros((3, 3)), 1, -1)
    check_slices(np.zeros((3, 3)), 1, slice(0, 2))
    check_slices(np.zeros((3, 3)), np.array([[0, 1, 2], [3, 4, 5]]), slice(0, 2))

    # A 0-d array cannot be indexed with an integer.
    with self.assertRaises(ValueError):
        a = mx.array(0)
        a[0] = mx.array(1)

    check_slices(np.zeros((3, 3)), 1, np.array([0, 1, 2]))
    check_slices(np.zeros((3, 3)), np.array(3), np.array([0, 1, 2]))
    check_slices(np.zeros((3, 3)), np.array([3]), np.array([0, 1, 2]))
    check_slices(np.zeros((3, 3)), np.array([3]), np.array([0, 1]))
    check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
    check_slices(
        np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
    )

    # Multiple slices
    a = mx.array(0)
    a[None, None] = 1
    self.assertEqual(a.item(), 1)
    a[None, None] = mx.array(2)
    self.assertEqual(a.item(), 2)
    a[None, None] = mx.array([[[3]]])
    self.assertEqual(a.item(), 3)
    a[()] = 4
    self.assertEqual(a.item(), 4)

    a_np = np.zeros((2, 3, 4, 5))
    check_slices(a_np, 1, np.array([0, 0]), slice(0, 2), slice(0, 3), 4)
    check_slices(
        a_np,
        np.arange(10).reshape(2, 5),
        np.array([0, 0]),
        np.array([0, 1]),
        np.array([2, 3]),
    )
    check_slices(
        a_np,
        np.array([[3], [4]]),
        np.array([0, 0]),
        np.array([0, 1]),
        np.array([2, 3]),
    )
    check_slices(
        a_np, np.arange(5), np.array([0, 0]), np.array([0, 1]), np.array([2, 3])
    )
    check_slices(np.zeros(5), np.arange(2), None, None, np.array([2, 3]))
    check_slices(
        np.zeros((4, 3, 4)),
        np.arange(3),
        np.array([2, 3]),
        slice(0, 3),
        np.array([2, 3]),
    )

    with self.assertRaises(ValueError):
        a = mx.zeros((4, 3, 4))
        a[mx.array([2, 3]), None, mx.array([2, 3])] = mx.arange(2)

    with self.assertRaises(ValueError):
        a = mx.zeros((4, 3, 4))
        a[mx.array([2, 3]), None, mx.array([2, 3])] = mx.arange(3)

    check_slices(np.zeros((4, 3, 4)), 1, np.array([2, 3]), None, np.array([2, 1]))
    check_slices(
        np.zeros((4, 3, 4)), np.arange(4), np.array([2, 3]), None, np.array([2, 1])
    )
    check_slices(
        np.zeros((4, 3, 4)),
        np.arange(2 * 4).reshape(2, 1, 4),
        np.array([2, 3]),
        None,
        np.array([2, 1]),
    )

    check_slices(np.zeros((4, 4)), 1, slice(0, 2), slice(0, 2))
    check_slices(np.zeros((4, 4)), np.arange(2), slice(0, 2), slice(0, 2))
    check_slices(
        np.zeros((4, 4)), np.arange(2).reshape(2, 1), slice(0, 2), slice(0, 2)
    )
    check_slices(
        np.zeros((4, 4)), np.arange(4).reshape(2, 2), slice(0, 2), slice(0, 2)
    )

    # Invalid keys: repeated ellipsis, or more indices than dimensions.
    with self.assertRaises(ValueError):
        a = mx.zeros((2, 2, 2))
        a[..., ...] = 1

    with self.assertRaises(ValueError):
        a = mx.zeros((2, 2, 2, 2, 2))
        a[0, ..., 0, ..., 0] = 1

    with self.assertRaises(ValueError):
        a = mx.zeros((2, 2))
        a[0, 0, 0] = 1

    # A (5, 1, 3) update cannot be written into the (5, 3) region a[:, 0].
    with self.assertRaises(ValueError):
        a = mx.zeros((5, 4, 3))
        a[:, 0] = mx.ones((5, 1, 3))

    check_slices(np.zeros((2, 2, 2, 2)), 1, None, Ellipsis, None)
    check_slices(
        np.zeros((2, 2, 2, 2)), 1, np.array([0, 1]), Ellipsis, np.array([0, 1])
    )
    check_slices(
        np.zeros((2, 2, 2, 2)),
        np.arange(2 * 2 * 2).reshape(2, 2, 2),
        np.array([0, 1]),
        Ellipsis,
        np.array([0, 1]),
    )

    # Check slice assign with negative indices works
    a = mx.zeros((5, 5), mx.int32)
    a[2:-2, 2:-2] = 4
    self.assertEqual(a[2, 2].item(), 4)

    # Check slice array slice
    check_slices(
        np.zeros((5, 4, 4)),
        np.arange(4 * 2 * 3).reshape(4, 2, 3),
        slice(0, 4),
        np.array([1, 3]),
        slice(None, -1),
    )
    check_slices(
        np.zeros((5, 4, 4)),
        np.arange(4 * 2 * 2).reshape(4, 2, 2),
        slice(0, 4),
        np.array([1, 3]),
        slice(0, 4, 2),
    )
    check_slices(
        np.zeros((1, 10, 4)),
        np.arange(2 * 4).reshape(1, 2, 4),
        slice(None, None, None),
        np.array([1, 3]),
    )
    check_slices(
        np.zeros((3, 4, 5, 3)),
        np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
        np.array([2, 1]),
        slice(None, None, None),
        slice(None, None, 2),
        slice(None, None, None),
    )
    check_slices(
        np.zeros((3, 4, 5, 3)),
        np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
        np.array([2, 1]),
        slice(None, None, None),
        slice(None, None, 2),
    )
    check_slices(np.zeros((5, 4, 3)), np.ones((5, 3)), slice(None), 0)
    check_slices(np.zeros((5, 4, 3)), np.ones((5, 1, 3)), slice(None), slice(0, 1))
    check_slices(
        np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)
    )

    x = mx.zeros((2, 3, 4, 5, 3))
    x[..., 0] = 1.0
    self.assertTrue(mx.array_equal(x[..., 0], mx.ones((2, 3, 4, 5))))

    x = mx.zeros((2, 3, 4, 5, 3))
    x[:, 0] = 1.0
    self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3))))

    x = mx.zeros((2, 2, 2, 2, 2, 2))
    x[0, 0] = 1
    self.assertTrue(mx.array_equal(x[0, 0], mx.ones((2, 2, 2, 2))))
def test_array_at(self):
    """Exercise the functional ``array.at[...]`` update API.

    Covers scalar, integer and array indices (repeated indices must
    accumulate), every ``.at`` update operation, and an update whose rank
    exceeds that of the indexed view.
    """
    # Index a 0-d array with None and add.
    a = mx.array(1)
    a = a.at[None].add(1)
    self.assertEqual(a.item(), 2)

    a = mx.array([0, 1, 2])
    a = a.at[1].add(2)
    self.assertEqual(a.tolist(), [0, 3, 2])

    # Repeated indices accumulate instead of overwriting each other.
    a = a.at[mx.array([0, 0, 0, 0])].add(1)
    self.assertEqual(a.tolist(), [4, 3, 2])

    a = mx.zeros((10, 10))
    a = a.at[0].add(mx.arange(10))
    self.assertEqual(a[0].tolist(), list(range(10)))

    a = mx.zeros((10, 10))
    index_x = mx.array([0, 2, 3, 7])
    index_y = mx.array([3, 3, 1, 2])
    u = mx.random.uniform(shape=(4,))
    a = a.at[index_x, index_y].add(u)
    self.assertTrue(mx.allclose(a.sum(), u.sum()))
    self.assertEqualArray(a.sum(), u.sum(), atol=1e-6, rtol=1e-5)
    self.assertEqual(a[index_x, index_y].tolist(), u.tolist())

    # Test all array.at ops
    a = mx.random.uniform(shape=(10, 5, 2))
    idx_x = mx.array([0, 4])
    update = mx.ones((2, 5))
    a[idx_x, :, 0] = 0
    a = a.at[idx_x, :, 0].add(update)
    self.assertEqualArray(a[idx_x, :, 0], update)
    a = a.at[idx_x, :, 0].subtract(update)
    self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update))
    a = a.at[idx_x, :, 0].add(2 * update)
    self.assertEqualArray(a[idx_x, :, 0], 2 * update)
    a = a.at[idx_x, :, 0].multiply(2 * update)
    self.assertEqualArray(a[idx_x, :, 0], 4 * update)
    a = a.at[idx_x, :, 0].divide(3 * update)
    self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update)

    # maximum/minimum: build the expected value from the known fill value
    # (5) rather than from the already-updated slice, which would make the
    # comparison nearly tautological.
    fill = mx.array(5.0)
    a[idx_x, :, 0] = 5
    update = mx.arange(10).reshape(2, 5)
    a = a.at[idx_x, :, 0].maximum(update)
    self.assertEqualArray(a[idx_x, :, 0], mx.maximum(fill, update))
    a[idx_x, :, 0] = 5
    a = a.at[idx_x, :, 0].minimum(update)
    self.assertEqualArray(a[idx_x, :, 0], mx.minimum(fill, update))

    # The update here has more leading singleton dims, (1, 1, 1, 2), than
    # the indexed view src[0:1], (1, 2); it must still apply.
    update = mx.array([1.0, 2.0])[None, None, None]
    src = mx.array([1.0, 2.0])[None, :]
    src = src.at[0:1].add(update)
    self.assertTrue(mx.array_equal(src, mx.array([[2.0, 4.0]])))
def test_slice_negative_step(self):
    """Negative-step slicing of mx arrays must agree exactly with NumPy."""

    def check(src_np, src_mx, keys):
        # Apply every index expression verbatim to both arrays and compare.
        for key in keys:
            self.assertTrue(np.array_equal(src_np[key], src_mx[key]))

    # 1-D: plain flip, bounded and out-of-range bounds, larger strides,
    # and ascending bounds with a negative step.
    flat_np = np.arange(20)
    check(
        flat_np,
        mx.array(flat_np),
        [
            np.s_[::-1],
            np.s_[-3:3:-1],
            np.s_[25:-50:-1],
            np.s_[::-3],
            np.s_[-3:3:-3],
            np.s_[25:-50:-3],
            np.s_[0:20:-3],
        ],
    )

    # 3-D: flip one dim at a time, then pairs of dims, then every dim.
    cube_np = np.arange(3 * 6 * 4).reshape(3, 6, 4)
    check(
        cube_np,
        mx.array(cube_np),
        [
            np.s_[..., ::-1],
            np.s_[:, ::-1, :],
            np.s_[::-1, ...],
            np.s_[::-1, 1:5:2, ::-2],
            np.s_[::-1, ::-2, 1:5:2],
            np.s_[::-1, ::-3, ::-2],
        ],
    )
def test_api(self):
    """Method-style calls on an array must match the ``mx.*`` functions."""
    x = mx.array(np.random.rand(10, 10, 10))
    # Each entry is either a bare op name or (op name, *positional args).
    ops = [
        ("reshape", (100, -1)),
        "square",
        "sqrt",
        "rsqrt",
        "reciprocal",
        "exp",
        "log",
        "sin",
        "cos",
        "log1p",
        "abs",
        "log10",
        "log2",
        "conj",
        ("all", 1),
        ("any", 1),
        ("transpose", (0, 2, 1)),
        ("sum", 1),
        ("prod", 1),
        ("min", 1),
        ("max", 1),
        ("logsumexp", 1),
        ("mean", 1),
        ("var", 1),
        ("argmin", 1),
        ("argmax", 1),
        ("cummax", 1),
        ("cummin", 1),
        ("cumprod", 1),
        ("cumsum", 1),
        ("diagonal", 0, 0, 1),
        ("flatten", 0, -1),
        ("moveaxis", 1, 2),
        ("round", 2),
        ("std", 1, True, 0),
        ("swapaxes", 1, 2),
    ]
    for entry in ops:
        name, *args = entry if isinstance(entry, tuple) else (entry,)
        from_fn = getattr(mx, name)(x, *args)
        from_method = getattr(x, name)(*args)
        self.assertEqual(from_fn.dtype, from_method.dtype)
        self.assertEqual(from_fn.shape, from_method.shape)
        self.assertTrue(mx.array_equal(from_fn, from_method))

    # split returns a list, so compare it piecewise.
    halves_fn = mx.split(x, 2)
    halves_method = x.split(2)
    self.assertEqual(len(halves_fn), 2)
    self.assertEqual(len(halves_fn), len(halves_method))
    for lhs, rhs in zip(halves_fn, halves_method):
        self.assertTrue(mx.array_equal(lhs, rhs))

    # squeeze is exercised with a keyword argument.
    x = mx.array(np.random.rand(10, 10, 1))
    squeezed_fn = mx.squeeze(x, axis=2)
    squeezed_method = x.squeeze(axis=2)
    self.assertEqual(squeezed_fn.shape, squeezed_method.shape)
    self.assertTrue(mx.array_equal(squeezed_fn, squeezed_method))
def test_memoryless_copy(self):
    """``np.array(..., copy=False)`` views mlx memory instead of copying it."""
    a_mx = mx.ones((2, 2))
    b_mx = mx.broadcast_to(a_mx, (5, 2, 2))

    # Make np arrays without copy
    a_np = np.array(a_mx, copy=False)
    b_np = np.array(b_mx, copy=False)

    # The NumPy view must not own the underlying data, and it is writeable.
    self.assertFalse(a_np.flags.owndata)
    self.assertTrue(a_np.flags.writeable)

    # Check contents
    self.assertTrue(np.array_equal(np.ones((2, 2), dtype=np.float32), a_np))
    self.assertTrue(np.array_equal(np.ones((5, 2, 2), dtype=np.float32), b_np))

    # The broadcast leading dim has stride 0. The remaining strides are those
    # of a contiguous (2, 2) float32 array: 8 bytes per row, 4 per element.
    self.assertSequenceEqual(b_np.strides, (0, 8, 4))
def test_np_array_conversion_copies_by_default(self):
    """Plain ``np.array(x)`` yields an owning, writeable NumPy copy."""
    source = mx.ones((2, 2))
    converted = np.array(source)
    flags = converted.flags
    self.assertTrue(flags.owndata)
    self.assertTrue(flags.writeable)
def test_buffer_protocol(self):
    """mx arrays expose PEP 3118 buffers that agree with NumPy's.

    For each dtype NumPy understands, the memoryview of an mx array and of
    its transpose must match NumPy's in strides, shape and format. It must
    be writeable and round-trip to an equal NumPy array with
    ``copy=False``. bfloat16, which NumPy cannot represent, is checked
    separately, as is ``bytes()`` on an array.
    """
    # (mlx dtype, numpy dtype, np.iinfo / np.finfo used to plant the
    # dtype's extreme values, or None when none applies)
    dtypes_list = [
        (mx.bool_, np.bool_, None),
        (mx.uint8, np.uint8, np.iinfo),
        (mx.uint16, np.uint16, np.iinfo),
        (mx.uint32, np.uint32, np.iinfo),
        (mx.uint64, np.uint64, np.iinfo),
        (mx.int8, np.int8, np.iinfo),
        (mx.int16, np.int16, np.iinfo),
        (mx.int32, np.int32, np.iinfo),
        (mx.int64, np.int64, np.iinfo),
        (mx.float16, np.float16, np.finfo),
        (mx.float32, np.float32, np.finfo),
        (mx.complex64, np.complex64, np.finfo),
    ]
    for mlx_dtype, np_dtype, info_fn in dtypes_list:
        label = f"{mlx_dtype}{np_dtype}"
        a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
        if info_fn is not None:
            # Plant the dtype's extreme values in the data.
            info = info_fn(np_dtype)
            a_np[0, 0] = info.min
            a_np[0, 1] = info.max
        a_mx = mx.array(a_np)
        # Check both the array itself and its (non-contiguous) transpose.
        for f in [lambda x: x, lambda x: x.T]:
            mv_mx = memoryview(f(a_mx))
            mv_np = memoryview(f(a_np))
            self.assertEqual(mv_mx.strides, mv_np.strides, label)
            self.assertEqual(mv_mx.shape, mv_np.shape, label)
            # correct buffer format for 8 byte (unsigned) 'long long' is Q/q, see
            # https://docs.python.org/3.10/library/struct.html#format-characters
            # numpy returns L/l, as 'long' is equivalent to 'long long' on 64bit
            # machines, so q and l are equivalent
            # see https://github.com/pybind/pybind11/issues/1908
            if np_dtype == np.uint64:
                self.assertEqual(mv_mx.format, "Q", label)
            elif np_dtype == np.int64:
                self.assertEqual(mv_mx.format, "q", label)
            else:
                self.assertEqual(mv_mx.format, mv_np.format, label)
            self.assertFalse(mv_mx.readonly)
            back_to_npy = np.array(mv_mx, copy=False)
            self.assertEqualArray(back_to_npy, f(a_np), atol=0, rtol=0)

    # bfloat16 has no NumPy equivalent. Its buffer reports format "B" with
    # 2-byte strides, and converting it to NumPy must fail.
    a_mx = mx.random.uniform(low=0, high=100, shape=(3, 4), dtype=mx.bfloat16)
    mv_mx = memoryview(a_mx)
    self.assertEqual(mv_mx.strides, (8, 2))
    self.assertEqual(mv_mx.shape, (3, 4))
    self.assertEqual(mv_mx.format, "B")
    with self.assertRaises(RuntimeError) as cm:
        np.array(a_mx)
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn(
        "Item size 2 for PEP 3118 buffer format string", str(cm.exception)
    )

    # Test buffer protocol with non-arrays ie bytes. Each int16 element is
    # ord("a") * 257 + i: one byte is "a", the other is chr(ord("a") + i).
    # Which one comes first depends on the platform byte order.
    a = ord("a") * 257 + mx.arange(10).astype(mx.int16)
    ab = bytes(a)
    self.assertEqual(len(ab), 20)
    if sys.byteorder == "little":
        self.assertEqual(b"aaaaaaaaaa", ab[1::2])
        self.assertEqual(b"abcdefghij", ab[::2])
    else:
        self.assertEqual(b"aaaaaaaaaa", ab[::2])
        self.assertEqual(b"abcdefghij", ab[1::2])
def test_buffer_protocol_ref_counting(self):
    """A memoryview keeps its source mx array alive until it is released."""
    source = mx.arange(3)
    ref = weakref.ref(source)
    self.assertIsNotNone(ref())

    view = memoryview(source)
    del source
    # With the name gone, the memoryview must keep the array alive.
    self.assertIsNotNone(ref())

    del view
    # Releasing the view frees the array.
    self.assertIsNone(ref())
def test_array_view_ref_counting(self):
    """A zero-copy NumPy view keeps its source mx array alive."""
    source = mx.arange(3)
    ref = weakref.ref(source)
    self.assertIsNotNone(ref())

    np_view = np.array(source, copy=False)
    del source
    # With the name gone, the NumPy view must keep the array alive.
    self.assertIsNotNone(ref())

    del np_view
    # Dropping the view frees the array.
    self.assertIsNone(ref())
@unittest.skipIf(not has_tf, "requires TensorFlow")
def test_buffer_protocol_tf(self):
    """mx and TensorFlow buffers agree, and mx buffers round-trip to NumPy.

    For each dtype, the memoryviews of an mx array and the equivalent TF
    tensor, and of their transposes, must report the same shape. They must
    also report the same strides whenever both share a contiguity layout.
    The mx view must be writeable and convert to a NumPy array equal to the
    TF data.
    """
    dtypes_list = [
        (mx.bool_, tf.bool, np.bool_),
        (mx.uint8, tf.uint8, np.uint8),
        (mx.uint16, tf.uint16, np.uint16),
        (mx.uint32, tf.uint32, np.uint32),
        (mx.uint64, tf.uint64, np.uint64),
        (mx.int8, tf.int8, np.int8),
        (mx.int16, tf.int16, np.int16),
        (mx.int32, tf.int32, np.int32),
        (mx.int64, tf.int64, np.int64),
        (mx.float16, tf.float16, np.float16),
        (mx.float32, tf.float32, np.float32),
        (mx.complex64, tf.complex64, np.complex64),
    ]

    def identity(x):
        return x

    def transpose(x):
        # TF tensors go through tf.transpose; mx arrays use .T.
        if isinstance(x, tf.Tensor):
            return tf.transpose(x)
        return x.T

    for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
        label = f"{mlx_dtype}{tf_dtype}"
        a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
        a_tf = tf.constant(a_np, dtype=tf_dtype)
        a_mx = mx.array(np.array(a_tf))
        for view_fn in (identity, transpose):
            mv_mx = memoryview(view_fn(a_mx))
            mv_tf = memoryview(view_fn(a_tf))
            # Strides are only comparable when both views share a layout.
            same_layout = (mv_mx.c_contiguous and mv_tf.c_contiguous) or (
                mv_mx.f_contiguous and mv_tf.f_contiguous
            )
            if same_layout:
                self.assertEqual(mv_mx.strides, mv_tf.strides, label)
            self.assertEqual(mv_mx.shape, mv_tf.shape, label)
            self.assertFalse(mv_mx.readonly)
            self.assertEqualArray(
                np.array(mv_mx),
                view_fn(a_tf),
                atol=0,
                rtol=0,
            )
def test_logical_overloads(self):
    """``&``, ``|`` and ``~`` behave as logical ops on boolean arrays."""
    # A float operand combined with an int array must raise ValueError.
    for binary_op in (operator.and_, operator.or_):
        with self.assertRaises(ValueError):
            binary_op(mx.array(1.0), mx.array(1))

    # (result, expected Python bool)
    cases = [
        (mx.array(True) & True, True),
        (mx.array(True) & False, False),
        (mx.array(True) | False, True),
        (mx.array(False) | False, False),
        (~mx.array(False), True),
    ]
    for result, expected in cases:
        self.assertEqual(result.item(), expected)
def test_inplace(self):
    """In-place operators mutate the array and agree with NumPy."""
    iops = [
        "__iadd__",
        "__isub__",
        "__imul__",
        "__ifloordiv__",
        "__imod__",
        "__ipow__",
    ]
    for op in iops:
        a = mx.array([1, 2, 3])
        a_np = np.array(a)
        # Calling the in-place dunder directly must mutate ``a`` itself, so
        # the returned array equals ``a`` and matches NumPy's result.
        b = getattr(a, op)(3)
        self.assertTrue(mx.array_equal(a, b))
        out_np = getattr(a_np, op)(3)
        self.assertTrue(np.array_equal(out_np, a))

    # In-place true division of an int array is rejected.
    with self.assertRaises(ValueError):
        a = mx.array([1])
        a /= 1

    # Each case below updates through an alias; the change must be visible
    # through the original name as well.
    a = mx.array([2.0])
    b = a
    b /= 2
    self.assertEqual(b.item(), 1.0)
    self.assertEqual(b.item(), a.item())

    a = mx.array(True)
    b = a
    b &= False
    self.assertEqual(b.item(), False)
    self.assertEqual(b.item(), a.item())

    a = mx.array(False)
    b = a
    b |= True
    self.assertEqual(b.item(), True)
    self.assertEqual(b.item(), a.item())

    # In-place matmul on its own
    a = mx.array([[1.0, 2.0], [3.0, 4.0]])
    b = a
    b @= a
    self.assertTrue(mx.array_equal(a, b))
def test_inplace_preserves_ids(self):
    """In-place ops and item assignment keep the same Python object."""
    a = mx.array([1.0])
    orig_id = id(a)

    def assert_same_object():
        # Reads the current binding of ``a`` from the enclosing scope.
        self.assertEqual(id(a), orig_id)

    a += mx.array(2.0)
    assert_same_object()
    a[0] = 2.0
    assert_same_object()
    a -= mx.array(3.0)
    assert_same_object()
    a *= mx.array(3.0)
    assert_same_object()
def test_load_from_pickled_np(self):
    """NumPy arrays restored from a pickle convert like the originals."""
    samples = (
        ([1, 2, 3], np.int32),
        ([1.0, 2.0, 3.0], np.float16),
    )
    for values, dtype in samples:
        original = np.array(values, dtype=dtype)
        # Round-trips data produced right here, not untrusted input.
        restored = pickle.loads(pickle.dumps(original))
        self.assertTrue(mx.array_equal(mx.array(original), mx.array(restored)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_multi_output_leak(self):
    """Dropping outputs of a multi-output op must not raise peak memory."""

    def peak_after(fn):
        # Run ``fn``, wait for queued work, then read the peak memory.
        fn()
        mx.synchronize()
        return mx.metal.get_peak_memory()

    def discard_both_outputs():
        a = mx.zeros(2**20)
        mx.eval(a)
        b, c = mx.divmod(a, a)
        del b, c

    def keep_one_output():
        a = mx.array([1.0, 2.0, 3.0, 4.0])
        b, _ = mx.divmod(a, a)
        return mx.log(b)

    for fn in (discard_both_outputs, keep_one_output):
        first = peak_after(fn)
        second = peak_after(fn)
        # A repeated identical run must leave the peak unchanged.
        self.assertEqual(first, second)
def test_add_numpy(self):
    """Adding a 0-d NumPy int32 array to an mx array gives an mx int32."""
    total = mx.array(1) + np.array(2, dtype=np.int32)
    self.assertEqual(total.dtype, mx.int32)
    self.assertEqual(total.item(), 3)
def test_dlpack(self):
    """NumPy can import mx arrays through DLPack, including strided views."""
    sources = [
        mx.array(1, dtype=mx.int32),
        mx.array([[1.0, 2.0], [3.0, 4.0]]),
        # Every other row and column: a non-contiguous view.
        mx.arange(16).reshape(4, 4)[::2, ::2],
    ]
    for src in sources:
        imported = np.from_dlpack(src)
        self.assertTrue(mx.array_equal(imported, src))
def test_getitem_with_list(self):
    """Indexing with a Python list matches NumPy fancy indexing."""
    single_index_cases = (
        (mx.array([1, 2, 3, 4, 5]), [0, 2, 4]),
        (mx.array([[1, 2], [3, 4], [5, 6]]), [0, 2]),
        (mx.arange(10).reshape(5, 2), [0, 2, 4]),
    )
    for src, idx in single_index_cases:
        self.assertTrue(np.array_equal(src[idx], np.array(src)[idx]))

    # A list index paired with an integer or a full slice, in either slot.
    idx = [0, 2]
    a = mx.arange(16).reshape(4, 4)
    anp = np.array(a)
    for key in ((idx, 0), (idx, slice(None)), (0, idx), (slice(None), idx)):
        self.assertTrue(np.array_equal(a[key], anp[key]))
def test_setitem_with_list(self):
    """Assigning through a list index matches NumPy fancy assignment."""
    single_index_cases = (
        (mx.array([1, 2, 3, 4, 5]), [0, 2, 4]),
        (mx.array([[1, 2], [3, 4], [5, 6]]), [0, 2]),
        (mx.arange(10).reshape(5, 2), [0, 2, 4]),
    )
    for src, idx in single_index_cases:
        expected = np.array(src)
        src[idx] = 3
        expected[idx] = 3
        self.assertTrue(np.array_equal(src, expected))

    # Successive assignments to one array. Each pairs a list index with an
    # integer or a full slice, and state carries over between steps.
    idx = [0, 2]
    a = mx.arange(16).reshape(4, 4)
    anp = np.array(a)
    steps = (
        ((idx, 0), 1),
        ((idx, slice(None)), 2),
        ((0, idx), 3),
        ((slice(None), idx), 4),
    )
    for key, value in steps:
        a[key] = value
        anp[key] = value
        self.assertTrue(np.array_equal(a, anp))
def test_array_namespace(self):
    """``__array_namespace__`` returns an object exposing array API names."""
    namespace = mx.array(1.0).__array_namespace__()
    for attr in ("array", "add"):
        self.assertTrue(hasattr(namespace, attr))
def test_to_scalar(self):
    """Scalar arrays convert via int()/float(); multi-element ones raise."""
    one = mx.array(1)
    self.assertEqual(int(one), 1)
    self.assertEqual(float(one), 1)

    one_and_half = mx.array(1.5)
    self.assertEqual(float(one_and_half), 1.5)
    self.assertEqual(int(one_and_half), 1)

    not_scalar = mx.zeros((2, 1))
    for convert in (float, int):
        with self.assertRaises(ValueError):
            convert(not_scalar)
def test_format(self):
    """Format specs apply to scalar arrays; an empty spec renders the array."""
    a = mx.arange(3)
    self.assertEqual(f"{a[0]:.2f}", "0.00")
    b = mx.array(0.35487)
    self.assertEqual(f"{b:.1f}", "0.4")
    # A numeric format spec on a non-scalar array raises TypeError.
    # format() is what the f-string calls; calling it directly avoids
    # binding the result to an unused local.
    with self.assertRaises(TypeError):
        format(a, ".2f")
    a = mx.array([1, 2, 3])
    self.assertEqual(f"{a}", "array([1, 2, 3], dtype=int32)")
def test_deep_graphs(self):
    """Very deep graphs are torn down, or evaluated, without crashing.

    Each section should simply run cleanly, without a segfault or a crash
    from exceeding recursion depth limits.
    """
    depth = 100_000

    # Deep graph destroyed without eval
    x = mx.array([1.0, 2.0])
    for _ in range(depth):
        x = mx.sin(x)
    del x

    # Duplicate input deep graph destroyed without eval. Delete explicitly,
    # like the other sections, so the teardown is exercised here instead of
    # incidentally when ``x`` is rebound in the next section.
    x = mx.array([1.0, 2.0])
    for _ in range(depth):
        x = x + x
    del x

    # Deep graph with siblings destroyed without eval
    x = mx.array([1, 2])
    for _ in range(depth):
        x = mx.concatenate(mx.split(x, 2))
    del x

    # Deep graph with eval
    x = mx.array([1.0, 2.0])
    for _ in range(depth):
        x = mx.sin(x)
    mx.eval(x)
def test_siblings_without_eval(self):
    """Unevaluated sibling outputs must not grow peak RSS across calls."""
    key = mx.array([1, 2])

    def peak_rss():
        # ru_maxrss is the process's maximum resident set size so far.
        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

    def build_graph():
        left, right = mx.split(key, 2)
        # Reshape both siblings but hand back only one of them.
        left = mx.reshape(left, [])
        right = mx.reshape(right, [])
        return right

    build_graph()
    gc.collect()
    baseline = peak_rss()
    for _ in range(100):
        build_graph()
    self.assertEqual(baseline, peak_rss())
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()