Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
This commit is contained in:
Awni Hannun
2025-02-05 19:50:08 -08:00
committed by GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions

View File

@@ -35,4 +35,29 @@ void shared_buffer_slice(
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
}
void slice(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
// Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
// data_end can be -1
size_t data_size =
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}
} // namespace mlx::core

View File

@@ -11,11 +11,10 @@ std::tuple<int64_t, Strides> prepare_slice(
const Shape& start_indices,
const Shape& strides);
void shared_buffer_slice(
void slice(
const array& in,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out);
array& out,
const Shape& start_indices,
const Shape& strides);
} // namespace mlx::core

View File

@@ -86,7 +86,7 @@ void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
slice(inputs[0], out, start_indices_, strides_);
}
void Split::eval_cpu(
const std::vector<array>& inputs,
@@ -262,29 +262,6 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
@@ -355,7 +332,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_inplace(

View File

@@ -11,7 +11,7 @@
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)

View File

@@ -499,8 +499,8 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_gpu_inplace(

View File

@@ -14,18 +14,7 @@ void slice_gpu(
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
// Calculate out strides and initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
size_t data_end = 1;
for (int i = 0; i < strides.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
slice(in, out, start_indices, strides);
}
void concatenate_gpu(

View File

@@ -1,5 +1,4 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
@@ -49,7 +48,7 @@ void unary_op_gpu_inplace(
} else {
kernel_name = "gn" + std::to_string(work_per_thread);
if (large) {
kernel_name += "_large";
kernel_name += "large";
}
}
concatenate(kernel_name, "_", op, type_to_name(in), type_to_name(out));

View File

@@ -599,7 +599,13 @@ array expand_dims(
namespace {
inline auto
normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
normalize_slice(const Shape& shape, Shape& start, Shape stop, Shape& strides) {
// - Start indices are normalized
// - End indices are unchanged as -1 means something different
// pre-normalization (the end of the axis) versus post normalization (the
// position left of 0).
// - Any strides corresponding to singleton dimension are set to 1
Shape out_shape(shape.size());
bool has_neg_strides = false;
@@ -624,10 +630,10 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
auto ed = e > -1 ? e : -1;
start[i] = st;
stop[i] = ed > st ? st : ed;
ed = ed > st ? st : ed;
auto str = -strides[i];
out_shape[i] = (start[i] - stop[i] + str - 1) / str;
out_shape[i] = (start[i] - ed + str - 1) / str;
} else {
// Clamp to bounds
@@ -635,9 +641,9 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
start[i] = st;
stop[i] = ed < st ? st : ed;
ed = ed < st ? st : ed;
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
out_shape[i] = (ed - start[i] + strides[i] - 1) / strides[i];
}
// Simplify the stride if it's unused
if (out_shape[i] == 1) {

View File

@@ -1229,58 +1229,45 @@ std::vector<array> Convolution::vjp(
in, wt, cotan, kernel_strides_, padding_, stream());
grads.push_back(grad);
} else {
if (flip_) {
auto padding = padding_;
for (int i = 0; i < padding.size(); i++) {
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding[i] = wt_size - padding_[i] - 1;
}
std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_;
auto cotan_trans = group_transpose(cotan, -1, 0, -1);
auto in_trans = swapaxes(in, 0, -1, stream());
auto grad_trans = conv_general(
/* const array& input = */ cotan_trans,
/* const array& weight = */ in_trans,
/* std::vector<int> stride = */ kernel_dilation_,
/* std::vector<int> padding_lo = */ padding,
/* std::vector<int> padding_hi = */ padding,
/* std::vector<int> kernel_dilation = */ input_dilation_,
/* std::vector<int> input_dilation = */ kernel_strides_,
/* int groups = */ groups_,
/* bool flip = */ false,
stream());
if (groups_ > 1) {
grads.push_back(group_transpose(grad_trans, -1, 0, -2));
} else {
grads.push_back(grad_trans);
}
} else {
std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_;
for (int i = 0; i < padding_hi.size(); ++i) {
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
}
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto in_trans = group_transpose(in, -1, 0, -1);
auto grad_trans = conv_general(
/* const array& input = */ in_trans,
/* const array& weight = */ cotan_trans,
/* std::vector<int> stride = */ kernel_dilation_,
/* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> kernel_dilation = */ kernel_strides_,
/* std::vector<int> input_dilation = */ input_dilation_,
/* int groups = */ groups_,
/* bool flip = */ false,
stream());
grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
for (int i = 0; i < padding_hi.size(); ++i) {
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
}
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto in_trans = group_transpose(in, -1, 0, -1);
auto grad_trans = conv_general(
/* const array& input = */ in_trans,
/* const array& weight = */ cotan_trans,
/* std::vector<int> stride = */ kernel_dilation_,
/* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> kernel_dilation = */ kernel_strides_,
/* std::vector<int> input_dilation = */ input_dilation_,
/* int groups = */ groups_,
/* bool flip = */ false,
stream());
if (flip_) {
auto start = Shape(grad_trans.ndim(), 0);
auto stop = Shape(grad_trans.ndim(), 0);
auto strides = Shape(grad_trans.ndim(), 1);
for (int i = 0; i < stop.size(); ++i) {
if (i >= 1 && i < stop.size() - 1) {
start[i] = grad_trans.shape(i);
stop[i] = -start[i] - 1;
strides[i] = -1;
} else {
stop[i] = grad_trans.shape(i);
}
}
grad_trans = slice(grad_trans, start, stop, strides, stream());
}
grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
}
}
}

View File

@@ -1921,7 +1921,6 @@ class Slice : public UnaryPrimitive {
Shape start_indices_;
Shape end_indices_;
Shape strides_;
void eval(const std::vector<array>& inputs, array& out);
};
class SliceUpdate : public UnaryPrimitive {