diff --git a/mlx/array.cpp b/mlx/array.cpp index 7f3dd854b..83c2fe6d7 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -82,6 +82,13 @@ array::array(std::initializer_list data) init(data.begin()); } +array::array(std::initializer_list data, Dtype dtype) + : array_desc_(std::make_shared( + std::vector{static_cast(data.size())}, + dtype)) { + init(data.begin()); +} + /* Build an array from a shared buffer */ array::array( allocator::Buffer data, diff --git a/mlx/array.h b/mlx/array.h index 5eefcf727..fe01cbfd7 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -41,6 +41,9 @@ class array { /* Special case so empty lists default to float32. */ array(std::initializer_list data); + /* Special case so array({}, type) is an empty array. */ + array(std::initializer_list data, Dtype dtype); + template array( std::initializer_list data, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 96107d515..01ee6d388 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { } if (repeats == 0) { - return array(std::initializer_list{}, arr.dtype()); + return array({}, arr.dtype()); } if (repeats == 1) { diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp index 080d53daa..62341c5c7 100644 --- a/tests/array_tests.cpp +++ b/tests/array_tests.cpp @@ -591,3 +591,21 @@ TEST_CASE("test array shared buffer") { eval(a + b); } + +TEST_CASE("test make empty array") { + auto a = array({}); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), float32); + + a = array({}, int32); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), int32); + + a = array({}, float32); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), float32); + + a = array({}, bool_); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), bool_); +}