mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
support python mlx.array creation from list of mlx.array's (#325)
* support python mlx.array creation from list of mlx.array's * include bfloat16 in UT * refactor so that sub array made of all python primitive types gets initialized by fill_vector * address PR comment: arr.shape().size() -> arr.ndim() * address PR comment: get back Dtype constness and let stack to handle type promotions automatically
This commit is contained in:
@@ -218,6 +218,64 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([1 + 0j, 2j, True, 0], mx.complex64)
|
||||
self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])
|
||||
|
||||
def test_construction_from_lists_of_mlx_arrays(self):
|
||||
dtypes = [
|
||||
mx.bool_,
|
||||
mx.uint8,
|
||||
mx.uint16,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
mx.int8,
|
||||
mx.int16,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
mx.float16,
|
||||
mx.float32,
|
||||
mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
for x_t, y_t in permutations(dtypes, 2):
|
||||
# check type promotion and numeric correctness
|
||||
x, y = mx.array([1.0], x_t), mx.array([2.0], y_t)
|
||||
z = mx.array([x, y])
|
||||
expected = mx.stack([x, y], axis=0)
|
||||
self.assertEqualArray(z, expected)
|
||||
|
||||
# check heterogeneous construction with mlx arrays and python primitive types
|
||||
x, y = mx.array([True], x_t), mx.array([False], y_t)
|
||||
z = mx.array([[x, [2.0]], [[3.0], y]])
|
||||
expected = mx.array([[[x.item()], [2.0]], [[3.0], [y.item()]]], z.dtype)
|
||||
self.assertEqualArray(z, expected)
|
||||
|
||||
# check when create from an array which does not contain memory to the raw data
|
||||
x = mx.array([1.0]).astype(mx.bfloat16) # x does not hold raw data
|
||||
for y_t in dtypes:
|
||||
y = mx.array([2.0], y_t)
|
||||
z = mx.array([x, y])
|
||||
expected = mx.stack([x, y], axis=0)
|
||||
self.assertEqualArray(z, expected)
|
||||
|
||||
# shape check from `stack()`
|
||||
with self.assertRaises(ValueError) as e:
|
||||
mx.array([x, 1.0])
|
||||
self.assertEqual(str(e.exception), "All arrays must have the same shape")
|
||||
|
||||
# shape check from `validate_shape`
|
||||
with self.assertRaises(ValueError) as e:
|
||||
mx.array([1.0, x])
|
||||
self.assertEqual(
|
||||
str(e.exception), "Initialization encountered non-uniform length."
|
||||
)
|
||||
|
||||
# check that `[mx.array, ...]` retains the `mx.array` in the graph
|
||||
def f(x):
|
||||
y = mx.array([x, mx.array([2.0])])
|
||||
return (2 * y).sum()
|
||||
|
||||
x = mx.array([1.0])
|
||||
dfdx = mx.grad(f)
|
||||
self.assertEqual(dfdx(x).item(), 2.0)
|
||||
|
||||
def test_init_from_array(self):
|
||||
x = mx.array(3.0)
|
||||
y = mx.array(x)
|
||||
|
Reference in New Issue
Block a user