Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Use int64 stride everywhere (#1671)
* use int64 stride everywhere
* fix ext
* fix ext
* more shape + cleanup
* one more
* few more
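For context, the diff below swaps std::vector<int> / std::vector<size_t> for the Shape and Strides aliases. A minimal, hedged sketch of what those aliases are assumed to look like (the real definitions live in mlx/array.h and may differ in detail):

#include <cstdint>
#include <vector>

// Assumed aliases, illustrative only:
using ShapeElem = int32_t;              // per-dimension extents stay 32-bit
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;   // strides become signed 64-bit everywhere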
@@ -144,8 +144,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
   throw std::invalid_argument(msg.str());
 }
 
-std::vector<std::vector<int>> Primitive::output_shapes(
-    const std::vector<array>&) {
+std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {
   std::ostringstream msg;
   msg << "[Primitive::output_shapes] ";
   this->print(msg);
@@ -969,7 +968,7 @@ array conv_weight_backward_patches(
   }
 
   // padded strides (contiguous)
-  std::vector<size_t> in_padded_strides(in.ndim(), 1);
+  Strides in_padded_strides(in.ndim(), 1);
   for (int i = in.ndim() - 2; i >= 0; --i) {
     in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];
   }
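The loop above is the standard row-major (C-contiguous) stride computation, now accumulating in 64-bit. A self-contained sketch of the same idea, with the alias definitions assumed rather than taken from the diff:

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

// Contiguous strides: the last axis has stride 1 and each earlier axis is the
// product of all later extents. Doing the products in int64_t keeps them exact
// for large arrays and allows signed stride arithmetic later on.
Strides contiguous_strides(const Shape& shape) {
  Strides strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  return strides;
}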
@@ -984,14 +983,13 @@ array conv_weight_backward_patches(
 
   // patches are shaped as
   // (batch_dim, out_spatial_dims, weight_spatial_dims, in_channels)
-  std::vector<int> patches_shape{
-      cotan.shape().begin(), cotan.shape().end() - 1};
+  Shape patches_shape{cotan.shape().begin(), cotan.shape().end() - 1};
   patches_shape.insert(
       patches_shape.end(), wt.shape().begin() + 1, wt.shape().end());
 
   // Resolve patch strides
   int n_spatial_dim = in.ndim() - 2;
-  std::vector<size_t> patches_strides(patches_shape.size(), 1);
+  Strides patches_strides(patches_shape.size(), 1);
   patches_strides[0] = in_padded_strides[0];
   for (int i = 1; i < n_spatial_dim + 1; i++) {
     patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];
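These strides describe a zero-copy, im2col-style "patches" view over the padded input: stepping once along an output spatial axis advances by the padded-input stride scaled by the convolution stride. A hedged sketch of just the leading strides shown in this hunk (function name and signature are assumptions, not the library's API):

#include <cstdint>
#include <vector>

using Strides = std::vector<int64_t>;

// Strides for the (batch, out_spatial...) axes of the patches view; the trailing
// weight-spatial/channel axes are filled later in the real function and are not
// part of this hunk.
Strides leading_patch_strides(
    const Strides& in_padded_strides,        // contiguous padded-input strides
    const std::vector<int>& kernel_strides,  // convolution stride per spatial dim
    int n_spatial_dim) {
  Strides patches_strides(n_spatial_dim + 1, 1);
  patches_strides[0] = in_padded_strides[0];  // batch axis
  for (int i = 1; i < n_spatial_dim + 1; i++) {
    patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];
  }
  return patches_strides;
}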
@@ -1095,8 +1093,8 @@ std::vector<array> Convolution::vjp(
 
       // Handle negative padding
       if (has_neg_padding) {
-        std::vector<int> starts(grad.ndim(), 0);
-        std::vector<int> stops = grad.shape();
+        Shape starts(grad.ndim(), 0);
+        auto stops = grad.shape();
 
         for (int i = 0; i < grad.ndim() - 2; i++) {
           if (padding_lo[i] < 0) {