diff --git a/python/src/array.cpp b/python/src/array.cpp index 392eb34a2..afb691428 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -229,9 +229,28 @@ array array_from_list( return array(vals.begin(), shape, specified_type.value_or(bool_)); } case pyint: { - std::vector vals; - fill_vector(pl, vals); - return array(vals.begin(), shape, specified_type.value_or(int32)); + auto dtype = specified_type.value_or(int32); + if (dtype == int64) { + std::vector vals; + fill_vector(pl, vals); + return array(vals.begin(), shape, dtype); + } else if (dtype == uint64) { + std::vector vals; + fill_vector(pl, vals); + return array(vals.begin(), shape, dtype); + } else if (dtype == uint32) { + std::vector vals; + fill_vector(pl, vals); + return array(vals.begin(), shape, dtype); + } else if (is_floating_point(dtype)) { + std::vector vals; + fill_vector(pl, vals); + return array(vals.begin(), shape, dtype); + } else { + std::vector vals; + fill_vector(pl, vals); + return array(vals.begin(), shape, dtype); + } } case pyfloat: { std::vector vals; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 94a9396b6..593dde361 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -226,6 +226,14 @@ class TestArray(mlx_tests.MLXTestCase): x = mx.array([1 + 0j, 2j, True, 0], mx.complex64) self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j]) + xnp = np.array([0, 4294967295], dtype=np.uint32) + x = mx.array([0, 4294967295], dtype=mx.uint32) + self.assertTrue(np.array_equal(x, xnp)) + + xnp = np.array([0, 4294967295], dtype=np.float32) + x = mx.array([0, 4294967295], dtype=mx.float32) + self.assertTrue(np.array_equal(x, xnp)) + def test_construction_from_lists_of_mlx_arrays(self): dtypes = [ mx.bool_,