mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix empty array construction in cpp (#684)
This commit is contained in:
parent
0c65517e91
commit
1eb04aa23f
@ -82,6 +82,13 @@ array::array(std::initializer_list<float> data)
|
|||||||
init(data.begin());
|
init(data.begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array::array(std::initializer_list<int> data, Dtype dtype)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
|
std::vector<int>{static_cast<int>(data.size())},
|
||||||
|
dtype)) {
|
||||||
|
init(data.begin());
|
||||||
|
}
|
||||||
|
|
||||||
/* Build an array from a shared buffer */
|
/* Build an array from a shared buffer */
|
||||||
array::array(
|
array::array(
|
||||||
allocator::Buffer data,
|
allocator::Buffer data,
|
||||||
|
@ -41,6 +41,9 @@ class array {
|
|||||||
/* Special case so empty lists default to float32. */
|
/* Special case so empty lists default to float32. */
|
||||||
array(std::initializer_list<float> data);
|
array(std::initializer_list<float> data);
|
||||||
|
|
||||||
|
/* Special case so array({}, type) is an empty array. */
|
||||||
|
array(std::initializer_list<int> data, Dtype dtype);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array(
|
array(
|
||||||
std::initializer_list<T> data,
|
std::initializer_list<T> data,
|
||||||
|
@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (repeats == 0) {
|
if (repeats == 0) {
|
||||||
return array(std::initializer_list<int>{}, arr.dtype());
|
return array({}, arr.dtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (repeats == 1) {
|
if (repeats == 1) {
|
||||||
|
@ -591,3 +591,21 @@ TEST_CASE("test array shared buffer") {
|
|||||||
|
|
||||||
eval(a + b);
|
eval(a + b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test make empty array") {
|
||||||
|
auto a = array({});
|
||||||
|
CHECK_EQ(a.size(), 0);
|
||||||
|
CHECK_EQ(a.dtype(), float32);
|
||||||
|
|
||||||
|
a = array({}, int32);
|
||||||
|
CHECK_EQ(a.size(), 0);
|
||||||
|
CHECK_EQ(a.dtype(), int32);
|
||||||
|
|
||||||
|
a = array({}, float32);
|
||||||
|
CHECK_EQ(a.size(), 0);
|
||||||
|
CHECK_EQ(a.dtype(), float32);
|
||||||
|
|
||||||
|
a = array({}, bool_);
|
||||||
|
CHECK_EQ(a.size(), 0);
|
||||||
|
CHECK_EQ(a.dtype(), bool_);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user