Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -144,8 +144,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
throw std::invalid_argument(msg.str());
}
std::vector<std::vector<int>> Primitive::output_shapes(
const std::vector<array>&) {
std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {
std::ostringstream msg;
msg << "[Primitive::output_shapes] ";
this->print(msg);
@@ -969,7 +968,7 @@ array conv_weight_backward_patches(
}
// padded strides (contiguous)
std::vector<size_t> in_padded_strides(in.ndim(), 1);
Strides in_padded_strides(in.ndim(), 1);
for (int i = in.ndim() - 2; i >= 0; --i) {
in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];
}
@@ -984,14 +983,13 @@ array conv_weight_backward_patches(
// patches are shaped as
// (batch_dim, out_spatial_dims, weight_spatial_dims, in_channels)
std::vector<int> patches_shape{
cotan.shape().begin(), cotan.shape().end() - 1};
Shape patches_shape{cotan.shape().begin(), cotan.shape().end() - 1};
patches_shape.insert(
patches_shape.end(), wt.shape().begin() + 1, wt.shape().end());
// Resolve patch strides
int n_spatial_dim = in.ndim() - 2;
std::vector<size_t> patches_strides(patches_shape.size(), 1);
Strides patches_strides(patches_shape.size(), 1);
patches_strides[0] = in_padded_strides[0];
for (int i = 1; i < n_spatial_dim + 1; i++) {
patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];
@@ -1095,8 +1093,8 @@ std::vector<array> Convolution::vjp(
// Handle negative padding
if (has_neg_padding) {
std::vector<int> starts(grad.ndim(), 0);
std::vector<int> stops = grad.shape();
Shape starts(grad.ndim(), 0);
auto stops = grad.shape();
for (int i = 0; i < grad.ndim() - 2; i++) {
if (padding_lo[i] < 0) {