Fix reshaping of empty arrays (#791)

This commit is contained in:
Angelos Katharopoulos 2024-03-05 23:33:22 -08:00 committed by GitHub
parent 14b4e51a7c
commit e39bebe13e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -293,26 +293,35 @@ array reshape(
for (int i = 0; i < shape.size(); ++i) { for (int i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) { if (shape[i] == -1) {
if (infer_idx >= 0) { if (infer_idx >= 0) {
throw std::invalid_argument("Reshape can only infer one dimension."); throw std::invalid_argument(
"[reshape] Reshape can only infer one dimension.");
} }
infer_idx = i; infer_idx = i;
} else { } else {
size *= shape[i]; size *= shape[i];
} }
} }
// Infer the shape
if (size > 0) { if (size > 0) {
auto q_and_r = std::ldiv(a.size(), size); auto q_and_r = std::ldiv(a.size(), size);
if (infer_idx >= 0) { if (infer_idx >= 0) {
shape[infer_idx] = q_and_r.quot; shape[infer_idx] = q_and_r.quot;
size *= q_and_r.quot; size *= q_and_r.quot;
} }
} else if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Cannot infer the shape of an empty array");
} }
// Check the the reshaping is valid
if (a.size() != size) { if (a.size() != size) {
std::ostringstream msg; std::ostringstream msg;
msg << "Cannot reshape array of size " << a.size() << " into shape " msg << "[reshape] Cannot reshape array of size " << a.size()
<< shape << "."; << " into shape " << shape << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
return array( return array(
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a}); shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
} }