More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -817,10 +817,10 @@ std::vector<array> Concatenate::vjp(
const std::vector<int>& argnums,
const std::vector<array>&) {
auto& cotan = cotangents[0];
std::vector<int> start(cotan.ndim(), 0);
std::vector<int> stop = cotan.shape();
Shape start(cotan.ndim(), 0);
Shape stop = cotan.shape();
std::vector<int> sizes;
Shape sizes;
sizes.push_back(0);
for (auto& p : primals) {
sizes.push_back(p.shape(axis_));
@@ -956,9 +956,9 @@ array conv_weight_backward_patches(
const std::vector<int>& padding,
StreamOrDevice s) {
// Resolve Padded input shapes and strides
std::vector<int> padding_starts(in.ndim(), 0);
std::vector<int> padding_ends = in.shape();
std::vector<int> in_padded_shape = in.shape();
Shape padding_starts(in.ndim(), 0);
auto padding_ends = in.shape();
auto in_padded_shape = in.shape();
// padded shape
for (int i = 1; i < in.ndim() - 1; i++) {
@@ -976,8 +976,9 @@ array conv_weight_backward_patches(
// Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1);
Shape padding_(padding.begin(), padding.end());
auto in_padded = pad(
in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
// Resolve strided patches
@@ -1797,7 +1798,7 @@ std::vector<array> FFT::vjp(
std::vector<int> axes(axes_.begin(), axes_.end());
if (real_ && inverse_) {
auto out = fft::fftn(cotangents[0], axes, stream());
auto start = std::vector<int>(out.ndim(), 0);
auto start = Shape(out.ndim(), 0);
auto stop = in.shape();
out = slice(out, start, stop, stream());
auto mask_shape = out.shape();
@@ -1809,7 +1810,7 @@ std::vector<array> FFT::vjp(
mask = concatenate({pad, mask, pad}, axes_.back(), stream());
return {multiply(mask, out, stream())};
} else if (real_) {
std::vector<int> n;
Shape n;
for (auto ax : axes_) {
n.push_back(in.shape()[ax]);
}
@@ -1934,10 +1935,11 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
}
if (indices_vmapped) {
// Make a new index array for the vmapped dimension
auto vmap_inds = arange(0, src.shape(axes[0]), stream());
auto vmap_inds =
arange(static_cast<ShapeElem>(0), src.shape(axes[0]), stream());
// Reshape it so it broadcasts with other index arrays
{
auto shape = std::vector<int>(idx_dims, 1);
auto shape = Shape(idx_dims, 1);
shape[out_ax] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(shape), stream());
}
@@ -2628,8 +2630,8 @@ std::vector<array> Pad::vjp(
assert(argnums.size() == 1 && argnums[0] == 0);
auto& cotan = cotangents[0];
std::vector<int> start(cotan.ndim(), 0);
std::vector<int> stop = cotan.shape();
Shape start(cotan.ndim(), 0);
auto stop = cotan.shape();
for (auto i : axes_) {
start[i] = low_pad_size_[i];
@@ -3019,7 +3021,7 @@ std::vector<array> Reduce::vjp(
const std::vector<array>& outputs) {
auto in = primals[0];
std::vector<int> shape = in.shape();
auto shape = in.shape();
for (auto ax : axes_) {
shape[ax] = 1;
}
@@ -3044,7 +3046,7 @@ std::vector<array> Reduce::vjp(
if (axes_.size() > 1) {
std::vector<int> transpose_to;
std::vector<int> transpose_back;
std::vector<int> shape_flat;
Shape shape_flat;
{
// Find the transpose needed to move axes_ to the back and the shape
// except the reduced over axes.
@@ -3422,7 +3424,7 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
}
auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
auto vmap_inds_shape = Shape(inputs[1].ndim(), 1);
vmap_inds_shape[0] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
inputs.insert(
@@ -3607,7 +3609,7 @@ std::vector<array> Slice::vjp(
// Transpose and reshape cotangents
auto cotan = cotangents[0];
if (!ind_axes.empty()) {
std::vector<int> cotan_shape;
Shape cotan_shape;
for (auto ax : ind_axes) {
cotan_shape.push_back(cotan.shape(ax));
}
@@ -3626,7 +3628,7 @@ std::vector<array> Slice::vjp(
}
// Make indices broadcastable
std::vector<int> inds_shape(inds.size(), 1);
Shape inds_shape(inds.size(), 1);
for (int i = 0; i < inds.size(); ++i) {
inds_shape[i] = inds[i].size();
inds[i] = reshape(inds[i], inds_shape, stream());
@@ -4184,7 +4186,7 @@ std::vector<array> BlockMaskedMM::vjp(
// Slice mask
mask_reshape[mask_ndim - 2] = Y;
mask_reshape[mask_ndim - 1] = X;
mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream());
mask = slice(mask, Shape(mask_ndim, 0), mask_reshape, stream());
return mask;
};
@@ -4202,7 +4204,7 @@ std::vector<array> BlockMaskedMM::vjp(
}
// Reshape
std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2);
Shape r_reshape(r.shape().begin(), r.shape().end() - 2);
r_reshape.push_back(r.shape(-2) / block_size_);
r_reshape.push_back(block_size_);
r_reshape.push_back(r.shape(-1) / block_size_);
@@ -4492,7 +4494,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
}
array out = array(
std::vector<int>{},
{},
dtype_,
std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
inputs);