Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
@@ -817,10 +817,10 @@ std::vector<array> Concatenate::vjp(
     const std::vector<int>& argnums,
     const std::vector<array>&) {
   auto& cotan = cotangents[0];
-  std::vector<int> start(cotan.ndim(), 0);
-  std::vector<int> stop = cotan.shape();
+  Shape start(cotan.ndim(), 0);
+  Shape stop = cotan.shape();
 
-  std::vector<int> sizes;
+  Shape sizes;
   sizes.push_back(0);
   for (auto& p : primals) {
     sizes.push_back(p.shape(axis_));
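For context on the types in these hunks: the change swaps hand-rolled `std::vector<int>` shape containers for the library's `Shape` alias. A minimal sketch of the idea follows, assuming `Shape` is a `std::vector`-style container of fixed-width integer extents (`ShapeElem`); the alias definitions and the helper below are illustrative assumptions, not code taken from this diff:

    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Assumed aliases standing in for the diff's Shape/ShapeElem.
    using ShapeElem = int32_t;
    using Shape = std::vector<ShapeElem>;

    // Hypothetical helper: given each input's extent along the concatenation
    // axis, build the running offsets [0, s0, s0+s1, ...] that a
    // Concatenate-style vjp can use as slice starts/stops when splitting the
    // cotangent back into per-input pieces.
    Shape concat_offsets(const Shape& extents) {
      Shape offsets(extents.size() + 1, 0);
      std::partial_sum(extents.begin(), extents.end(), offsets.begin() + 1);
      return offsets;
    }

Because the assumed `Shape` is still a `std::vector` specialization, the surrounding code keeps using `push_back`, iterator-range construction, and indexing unchanged.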
@@ -956,9 +956,9 @@ array conv_weight_backward_patches(
     const std::vector<int>& padding,
     StreamOrDevice s) {
   // Resolve Padded input shapes and strides
-  std::vector<int> padding_starts(in.ndim(), 0);
-  std::vector<int> padding_ends = in.shape();
-  std::vector<int> in_padded_shape = in.shape();
+  Shape padding_starts(in.ndim(), 0);
+  auto padding_ends = in.shape();
+  auto in_padded_shape = in.shape();
 
   // padded shape
   for (int i = 1; i < in.ndim() - 1; i++) {
@@ -976,8 +976,9 @@ array conv_weight_backward_patches(
   // Pad input
   std::vector<int> padded_axes(in.ndim() - 2, 0);
   std::iota(padded_axes.begin(), padded_axes.end(), 1);
+  Shape padding_(padding.begin(), padding.end());
   auto in_padded = pad(
-      in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
+      in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
 
   // Resolve strided patches
 
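The extra `padding_` local above copies the `std::vector<int>` parameter into a `Shape` before calling `pad`, presumably because the relevant `pad` overload now takes `Shape` arguments for the low/high pad widths. A standalone sketch of that conversion pattern; the alias and helper name are assumptions, not mlx API:

    #include <cstdint>
    #include <vector>

    using ShapeElem = int32_t;             // assumed element type
    using Shape = std::vector<ShapeElem>;  // assumed alias

    // Hypothetical helper: element-wise copy of an int vector into a Shape,
    // the same iterator-range construction the hunk uses for `padding_`.
    Shape to_shape(const std::vector<int>& v) {
      return Shape(v.begin(), v.end());
    }

Making the copy once up front keeps the call site short and avoids repeating the conversion for the low and high pad arguments.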
@@ -1797,7 +1798,7 @@ std::vector<array> FFT::vjp(
   std::vector<int> axes(axes_.begin(), axes_.end());
   if (real_ && inverse_) {
     auto out = fft::fftn(cotangents[0], axes, stream());
-    auto start = std::vector<int>(out.ndim(), 0);
+    auto start = Shape(out.ndim(), 0);
     auto stop = in.shape();
     out = slice(out, start, stop, stream());
     auto mask_shape = out.shape();
@@ -1809,7 +1810,7 @@ std::vector<array> FFT::vjp(
     mask = concatenate({pad, mask, pad}, axes_.back(), stream());
     return {multiply(mask, out, stream())};
   } else if (real_) {
-    std::vector<int> n;
+    Shape n;
     for (auto ax : axes_) {
       n.push_back(in.shape()[ax]);
     }
@@ -1934,10 +1935,11 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
   }
   if (indices_vmapped) {
     // Make a new index array for the vmapped dimension
-    auto vmap_inds = arange(0, src.shape(axes[0]), stream());
+    auto vmap_inds =
+        arange(static_cast<ShapeElem>(0), src.shape(axes[0]), stream());
     // Reshape it so it broadcasts with other index arrays
     {
-      auto shape = std::vector<int>(idx_dims, 1);
+      auto shape = Shape(idx_dims, 1);
       shape[out_ax] = vmap_inds.size();
       vmap_inds = reshape(vmap_inds, std::move(shape), stream());
     }
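The `static_cast<ShapeElem>(0)` in the `arange` call keeps both bounds the same type. A self-contained illustration (not mlx code) of the overload-resolution problem that can appear once shape extents are no longer plain `int`; the `long long` element type here is hypothetical:

    // Two overloads standing in for an arange-like API that has both an int
    // and a wider-integer version.
    long long span(long long start, long long stop) { return stop - start; }
    int span(int start, int stop) { return stop - start; }

    using ShapeElem = long long;  // hypothetical wider shape element type

    int main() {
      ShapeElem extent = 5;
      // span(0, extent);  // ambiguous: 0 prefers the int overload while
      //                   // extent prefers the long long one
      auto n = span(static_cast<ShapeElem>(0), extent);  // unambiguous
      return static_cast<int>(n);
    }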
@@ -2628,8 +2630,8 @@ std::vector<array> Pad::vjp(
   assert(argnums.size() == 1 && argnums[0] == 0);
 
   auto& cotan = cotangents[0];
-  std::vector<int> start(cotan.ndim(), 0);
-  std::vector<int> stop = cotan.shape();
+  Shape start(cotan.ndim(), 0);
+  auto stop = cotan.shape();
 
   for (auto i : axes_) {
     start[i] = low_pad_size_[i];
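In Pad::vjp the `start`/`stop` pair marks the window of the padded cotangent that maps back onto the original input: the loop shifts each padded axis's start by the low pad (the matching stop adjustment lies outside this hunk). A small sketch of that window computation, with the alias and names being illustrative assumptions:

    #include <cstdint>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int32_t>;  // assumed alias

    // Hypothetical helper: given the padded shape, the padded axes, and the
    // per-axis low/high pad sizes, compute the [start, stop) window that
    // recovers the unpadded region.
    std::pair<Shape, Shape> unpad_window(
        const Shape& padded_shape,
        const std::vector<int>& axes,
        const std::vector<int>& low_pad,
        const std::vector<int>& high_pad) {
      Shape start(padded_shape.size(), 0);
      Shape stop = padded_shape;
      for (auto ax : axes) {
        start[ax] = low_pad[ax];   // skip the low padding
        stop[ax] -= high_pad[ax];  // drop the high padding
      }
      return {start, stop};
    }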
@@ -3019,7 +3021,7 @@ std::vector<array> Reduce::vjp(
     const std::vector<array>& outputs) {
   auto in = primals[0];
 
-  std::vector<int> shape = in.shape();
+  auto shape = in.shape();
   for (auto ax : axes_) {
     shape[ax] = 1;
   }
@@ -3044,7 +3046,7 @@ std::vector<array> Reduce::vjp(
   if (axes_.size() > 1) {
     std::vector<int> transpose_to;
     std::vector<int> transpose_back;
-    std::vector<int> shape_flat;
+    Shape shape_flat;
     {
       // Find the transpose needed to move axes_ to the back and the shape
       // except the reduced over axes.
@@ -3422,7 +3424,7 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
   }
 
   auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
-  auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
+  auto vmap_inds_shape = Shape(inputs[1].ndim(), 1);
   vmap_inds_shape[0] = vmap_inds.size();
   vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
   inputs.insert(
@@ -3607,7 +3609,7 @@ std::vector<array> Slice::vjp(
   // Transpose and reshape cotangents
   auto cotan = cotangents[0];
   if (!ind_axes.empty()) {
-    std::vector<int> cotan_shape;
+    Shape cotan_shape;
     for (auto ax : ind_axes) {
       cotan_shape.push_back(cotan.shape(ax));
     }
@@ -3626,7 +3628,7 @@ std::vector<array> Slice::vjp(
   }
 
   // Make indices broadcastable
-  std::vector<int> inds_shape(inds.size(), 1);
+  Shape inds_shape(inds.size(), 1);
   for (int i = 0; i < inds.size(); ++i) {
     inds_shape[i] = inds[i].size();
     inds[i] = reshape(inds[i], inds_shape, stream());
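The `inds_shape` pattern above gives index array i a shape that is 1 on every index dimension except position i, so the k index arrays broadcast against each other into a k-dimensional index grid. A standalone sketch of that shape construction, with assumed alias and names:

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int32_t>;  // assumed alias

    // Hypothetical helper: shape for the i-th of `num_indices` index arrays
    // so that the reshaped indices broadcast into a full index grid.
    Shape broadcast_index_shape(int num_indices, int position, int32_t length) {
      Shape shape(num_indices, 1);  // all-ones shape
      shape[position] = length;     // keep this index's own extent
      return shape;
    }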
@@ -4184,7 +4186,7 @@ std::vector<array> BlockMaskedMM::vjp(
     // Slice mask
     mask_reshape[mask_ndim - 2] = Y;
     mask_reshape[mask_ndim - 1] = X;
-    mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream());
+    mask = slice(mask, Shape(mask_ndim, 0), mask_reshape, stream());
 
     return mask;
   };
@@ -4202,7 +4204,7 @@ std::vector<array> BlockMaskedMM::vjp(
     }
 
     // Reshape
-    std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2);
+    Shape r_reshape(r.shape().begin(), r.shape().end() - 2);
     r_reshape.push_back(r.shape(-2) / block_size_);
     r_reshape.push_back(block_size_);
     r_reshape.push_back(r.shape(-1) / block_size_);
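The `r_reshape` construction keeps every leading (batch) dimension of `r` and then splits the trailing two extents into block-count and block-size pairs; the final `push_back(block_size_)` presumably sits just past the end of this hunk. A standalone sketch of that blocked reshape, assuming a `std::vector`-style Shape:

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int32_t>;  // assumed alias

    // Hypothetical helper: turn a trailing (M, N) pair of extents into
    // (M / block, block, N / block, block), keeping the leading dims as-is.
    Shape block_reshape(const Shape& s, int32_t block) {
      Shape out(s.begin(), s.end() - 2);  // leading (batch) dimensions
      auto m = s[s.size() - 2];
      auto n = s[s.size() - 1];
      out.push_back(m / block);
      out.push_back(block);
      out.push_back(n / block);
      out.push_back(block);  // assumed continuation of the pattern
      return out;
    }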
@@ -4492,7 +4494,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
   }
 
   array out = array(
-      std::vector<int>{},
+      {},
       dtype_,
       std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
       inputs);
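In the last hunk, replacing `std::vector<int>{}` with `{}` relies on the array constructor now taking a single, known shape type, so an empty braced list is enough to spell a 0-d (scalar) shape. A minimal sketch of why the shorter spelling works; the types here are stand-ins, not the mlx API:

    #include <cstdint>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int32_t>;  // assumed alias

    // Stand-in for a constructor that takes the output shape by value.
    struct FakeArray {
      Shape shape;
      explicit FakeArray(Shape s) : shape(std::move(s)) {}
    };

    int main() {
      FakeArray scalar({});  // empty braced list -> empty Shape -> 0-d output
      return static_cast<int>(scalar.shape.size());
    }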