Fix empty array construction in cpp (#684)

This commit is contained in:
Awni Hannun 2024-02-13 23:34:17 -08:00 committed by GitHub
parent 0c65517e91
commit 1eb04aa23f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 29 additions and 1 deletions

View File

@ -82,6 +82,13 @@ array::array(std::initializer_list<float> data)
init(data.begin()); init(data.begin());
} }
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
dtype)) {
init(data.begin());
}
/* Build an array from a shared buffer */ /* Build an array from a shared buffer */
array::array( array::array(
allocator::Buffer data, allocator::Buffer data,

View File

@ -41,6 +41,9 @@ class array {
/* Special case so empty lists default to float32. */ /* Special case so empty lists default to float32. */
array(std::initializer_list<float> data); array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype);
template <typename T> template <typename T>
array( array(
std::initializer_list<T> data, std::initializer_list<T> data,

View File

@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
} }
if (repeats == 0) { if (repeats == 0) {
return array(std::initializer_list<int>{}, arr.dtype()); return array({}, arr.dtype());
} }
if (repeats == 1) { if (repeats == 1) {

View File

@ -591,3 +591,21 @@ TEST_CASE("test array shared buffer") {
eval(a + b); eval(a + b);
} }
TEST_CASE("test make empty array") {
auto a = array({});
CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), float32);
a = array({}, int32);
CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), int32);
a = array({}, float32);
CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), float32);
a = array({}, bool_);
CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), bool_);
}