fix build w/ flatten (#195)

This commit is contained in:
Awni Hannun
2023-12-17 11:58:45 -08:00
committed by GitHub
parent 52e1589a52
commit 90d04072b7
3 changed files with 30 additions and 9 deletions

View File

@@ -283,21 +283,36 @@ array flatten(
int end_axis /* = -1 */,
StreamOrDevice s /* = {} */) {
auto ndim = static_cast<int>(a.ndim());
start_axis += (start_axis < 0 ? ndim : 0);
end_axis += (end_axis < 0 ? ndim + 1 : 0);
start_axis = std::max(0, start_axis);
end_axis = std::min(ndim, end_axis);
if (end_axis < start_axis) {
auto start_ax = start_axis + (start_axis < 0 ? ndim : 0);
auto end_ax = end_axis + (end_axis < 0 ? ndim : 0);
start_ax = std::max(0, start_ax);
end_ax = std::min(ndim - 1, end_ax);
if (a.ndim() == 0) {
return reshape(a, {1}, s);
}
if (end_ax < start_ax) {
throw std::invalid_argument(
"[flatten] start_axis must be less than or equal to end_axis");
}
if (start_axis == end_axis and a.ndim() != 0) {
if (start_ax >= ndim) {
std::ostringstream msg;
msg << "[flatten] Invalid start_axis " << start_axis << " for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (end_ax < 0) {
std::ostringstream msg;
msg << "[flatten] Invalid end_axis " << end_axis << " for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (start_ax == end_ax) {
return a;
}
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_axis);
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax);
new_shape.push_back(-1);
new_shape.insert(
new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end());
new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
return reshape(a, new_shape, s);
}