support python mlx.array creation from list of mlx.array's (#325)

* support python mlx.array creation from list of mlx.array's

* include bfloat16 in UT

* refactor so that sub array made of all python primitive types gets initialized by fill_vector

* address PR comment: arr.shape().size() -> arr.ndim()

* address PR comment: get back Dtype constness and let stack to handle type promotions automatically
This commit is contained in:
mutexuan
2024-01-05 10:53:33 +08:00
committed by GitHub
parent b9e415d19c
commit d8f41a5c0f
3 changed files with 180 additions and 63 deletions

View File

@@ -218,6 +218,64 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([1 + 0j, 2j, True, 0], mx.complex64)
self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])
def test_construction_from_lists_of_mlx_arrays(self):
dtypes = [
mx.bool_,
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
mx.int8,
mx.int16,
mx.int32,
mx.int64,
mx.float16,
mx.float32,
mx.bfloat16,
mx.complex64,
]
for x_t, y_t in permutations(dtypes, 2):
# check type promotion and numeric correctness
x, y = mx.array([1.0], x_t), mx.array([2.0], y_t)
z = mx.array([x, y])
expected = mx.stack([x, y], axis=0)
self.assertEqualArray(z, expected)
# check heterogeneous construction with mlx arrays and python primitive types
x, y = mx.array([True], x_t), mx.array([False], y_t)
z = mx.array([[x, [2.0]], [[3.0], y]])
expected = mx.array([[[x.item()], [2.0]], [[3.0], [y.item()]]], z.dtype)
self.assertEqualArray(z, expected)
# check when create from an array which does not contain memory to the raw data
x = mx.array([1.0]).astype(mx.bfloat16) # x does not hold raw data
for y_t in dtypes:
y = mx.array([2.0], y_t)
z = mx.array([x, y])
expected = mx.stack([x, y], axis=0)
self.assertEqualArray(z, expected)
# shape check from `stack()`
with self.assertRaises(ValueError) as e:
mx.array([x, 1.0])
self.assertEqual(str(e.exception), "All arrays must have the same shape")
# shape check from `validate_shape`
with self.assertRaises(ValueError) as e:
mx.array([1.0, x])
self.assertEqual(
str(e.exception), "Initialization encountered non-uniform length."
)
# check that `[mx.array, ...]` retains the `mx.array` in the graph
def f(x):
y = mx.array([x, mx.array([2.0])])
return (2 * y).sum()
x = mx.array([1.0])
dfdx = mx.grad(f)
self.assertEqual(dfdx(x).item(), 2.0)
def test_init_from_array(self):
x = mx.array(3.0)
y = mx.array(x)