Fix reshaping of empty arrays (#791)

This commit is contained in:
Angelos Katharopoulos 2024-03-05 23:33:22 -08:00 committed by GitHub
parent 14b4e51a7c
commit e39bebe13e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -293,26 +293,35 @@ array reshape(
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) {
if (infer_idx >= 0) {
throw std::invalid_argument("Reshape can only infer one dimension.");
throw std::invalid_argument(
"[reshape] Reshape can only infer one dimension.");
}
infer_idx = i;
} else {
size *= shape[i];
}
}
// Infer the shape
if (size > 0) {
auto q_and_r = std::ldiv(a.size(), size);
if (infer_idx >= 0) {
shape[infer_idx] = q_and_r.quot;
size *= q_and_r.quot;
}
} else if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Cannot infer the shape of an empty array");
}
// Check the the reshaping is valid
if (a.size() != size) {
std::ostringstream msg;
msg << "Cannot reshape array of size " << a.size() << " into shape "
<< shape << ".";
msg << "[reshape] Cannot reshape array of size " << a.size()
<< " into shape " << shape << ".";
throw std::invalid_argument(msg.str());
}
return array(
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
}