More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -649,7 +649,7 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
// Clamp to bounds
auto st = std::min(s, n - 1);
auto ed = std::max(-1, e);
auto ed = e > -1 ? e : -1;
start[i] = st;
stop[i] = ed > st ? st : ed;
@@ -659,8 +659,8 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
} else {
// Clamp to bounds
auto st = std::max(0, std::min(s, n));
auto ed = std::max(0, std::min(e, n));
auto st = std::max(static_cast<ShapeElem>(0), std::min(s, n));
auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
start[i] = st;
stop[i] = ed < st ? st : ed;
@@ -765,7 +765,7 @@ array slice_update(
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
const Shape& indices,
int axis,
StreamOrDevice s /* = {} */) {
auto ax = axis < 0 ? axis + a.ndim() : axis;
@@ -809,10 +809,8 @@ std::vector<array> split(
return res;
}
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
StreamOrDevice s /* = {} */) {
std::vector<array>
split(const array& a, const Shape& indices, StreamOrDevice s /* = {} */) {
return split(a, indices, 0, s);
}
@@ -834,7 +832,7 @@ split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
throw std::invalid_argument(msg.str());
}
auto split_size = q_and_r.quot;
std::vector<int> indices(num_splits - 1);
Shape indices(num_splits - 1);
for (int i = 0; i < indices.size(); ++i) {
indices[i] = (i + 1) * split_size;
}
@@ -1104,7 +1102,7 @@ array edge_pad(
/** Pad an array with a constant value */
array pad(
const array& a,
const Shape& axes,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Shape& high_pad_size,
const array& pad_value /*= array(0)*/,
@@ -1904,9 +1902,11 @@ array min(
array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size();
auto result = argmin(reshape(a, {size}, s), 0, true, s);
auto result = argmin(flatten(a, s), 0, true, s);
if (keepdims) {
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
std::vector<int> axes(a.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
result = expand_dims(result, axes, s);
} else {
result = squeeze(result, s);
}
@@ -1940,9 +1940,11 @@ array argmin(
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size();
auto result = argmax(reshape(a, {size}, s), 0, true, s);
auto result = argmax(flatten(a, s), 0, true, s);
if (keepdims) {
result = reshape(result, Shape(a.shape().size(), 1), s);
std::vector<int> axes(a.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
result = expand_dims(result, axes, s);
} else {
result = squeeze(result, s);
}
@@ -3238,8 +3240,8 @@ inline int dilate_size(int dim, int dil) {
}
Shape conv_out_shape(
const std::vector<int>& in_shape,
const std::vector<int>& wt_shape,
const Shape& in_shape,
const Shape& wt_shape,
const std::vector<int>& strides,
const std::vector<int>& pads_lo,
const std::vector<int>& pads_hi,
@@ -4329,16 +4331,16 @@ array diagonal(
"[diagonal] axis1 and axis2 cannot be the same axis");
}
auto off1 = std::max(-offset, 0);
auto off2 = std::max(offset, 0);
ShapeElem off1 = std::max(-offset, 0);
ShapeElem off2 = std::max(offset, 0);
auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
diag_size = std::max(diag_size, 0);
diag_size = diag_size < 0 ? 0 : diag_size;
std::vector<array> indices = {
arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
std::vector<int> slice_sizes = a.shape();
Shape slice_sizes = a.shape();
slice_sizes[ax1] = 1;
slice_sizes[ax2] = 1;