mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Fix a couple of slicing bugs (#1827)
* fix a few bugs * fix conv grad * speedup test * comment
This commit is contained in:
parent
9174606d4c
commit
af1b725fda
@ -35,4 +35,29 @@ void shared_buffer_slice(
|
||||
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
void slice(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate out strides, initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
int64_t data_end = 1;
|
||||
for (int i = 0; i < start_indices.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
}
|
||||
// data_end can be -1
|
||||
size_t data_size =
|
||||
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset);
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -11,11 +11,10 @@ std::tuple<int64_t, Strides> prepare_slice(
|
||||
const Shape& start_indices,
|
||||
const Shape& strides);
|
||||
|
||||
void shared_buffer_slice(
|
||||
void slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
array& out);
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -86,7 +86,7 @@ void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
slice(inputs[0], out, start_indices_, strides_);
|
||||
}
|
||||
void Split::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
@ -262,29 +262,6 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
Strides ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
@ -355,7 +332,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_inplace(
|
||||
|
@ -11,7 +11,7 @@
|
||||
instantiate_kernel( \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
|
||||
instantiate_kernel( \
|
||||
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
#define instantiate_unary_all_same(op, tname, type) \
|
||||
instantiate_unary_all(op, tname, tname, type, type)
|
||||
|
@ -499,8 +499,8 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_gpu_inplace(
|
||||
|
@ -14,18 +14,7 @@ void slice_gpu(
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
// Calculate out strides and initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < strides.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
void concatenate_gpu(
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@ -49,7 +48,7 @@ void unary_op_gpu_inplace(
|
||||
} else {
|
||||
kernel_name = "gn" + std::to_string(work_per_thread);
|
||||
if (large) {
|
||||
kernel_name += "_large";
|
||||
kernel_name += "large";
|
||||
}
|
||||
}
|
||||
concatenate(kernel_name, "_", op, type_to_name(in), type_to_name(out));
|
||||
|
16
mlx/ops.cpp
16
mlx/ops.cpp
@ -599,7 +599,13 @@ array expand_dims(
|
||||
namespace {
|
||||
|
||||
inline auto
|
||||
normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
normalize_slice(const Shape& shape, Shape& start, Shape stop, Shape& strides) {
|
||||
// - Start indices are normalized
|
||||
// - End indices are unchanged as -1 means something different
|
||||
// pre-normalization (the end of the axis) versus post normalization (the
|
||||
// position left of 0).
|
||||
// - Any strides corresponding to singleton dimension are set to 1
|
||||
|
||||
Shape out_shape(shape.size());
|
||||
bool has_neg_strides = false;
|
||||
|
||||
@ -624,10 +630,10 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
auto ed = e > -1 ? e : -1;
|
||||
|
||||
start[i] = st;
|
||||
stop[i] = ed > st ? st : ed;
|
||||
ed = ed > st ? st : ed;
|
||||
|
||||
auto str = -strides[i];
|
||||
out_shape[i] = (start[i] - stop[i] + str - 1) / str;
|
||||
out_shape[i] = (start[i] - ed + str - 1) / str;
|
||||
|
||||
} else {
|
||||
// Clamp to bounds
|
||||
@ -635,9 +641,9 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
|
||||
|
||||
start[i] = st;
|
||||
stop[i] = ed < st ? st : ed;
|
||||
ed = ed < st ? st : ed;
|
||||
|
||||
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
|
||||
out_shape[i] = (ed - start[i] + strides[i] - 1) / strides[i];
|
||||
}
|
||||
// Simplify the stride if it's unused
|
||||
if (out_shape[i] == 1) {
|
||||
|
@ -1229,58 +1229,45 @@ std::vector<array> Convolution::vjp(
|
||||
in, wt, cotan, kernel_strides_, padding_, stream());
|
||||
grads.push_back(grad);
|
||||
} else {
|
||||
if (flip_) {
|
||||
auto padding = padding_;
|
||||
for (int i = 0; i < padding.size(); i++) {
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding[i] = wt_size - padding_[i] - 1;
|
||||
}
|
||||
std::vector<int> padding_lo = padding_;
|
||||
std::vector<int> padding_hi = padding_;
|
||||
|
||||
auto cotan_trans = group_transpose(cotan, -1, 0, -1);
|
||||
auto in_trans = swapaxes(in, 0, -1, stream());
|
||||
|
||||
auto grad_trans = conv_general(
|
||||
/* const array& input = */ cotan_trans,
|
||||
/* const array& weight = */ in_trans,
|
||||
/* std::vector<int> stride = */ kernel_dilation_,
|
||||
/* std::vector<int> padding_lo = */ padding,
|
||||
/* std::vector<int> padding_hi = */ padding,
|
||||
/* std::vector<int> kernel_dilation = */ input_dilation_,
|
||||
/* std::vector<int> input_dilation = */ kernel_strides_,
|
||||
/* int groups = */ groups_,
|
||||
/* bool flip = */ false,
|
||||
stream());
|
||||
if (groups_ > 1) {
|
||||
grads.push_back(group_transpose(grad_trans, -1, 0, -2));
|
||||
} else {
|
||||
grads.push_back(grad_trans);
|
||||
}
|
||||
} else {
|
||||
std::vector<int> padding_lo = padding_;
|
||||
std::vector<int> padding_hi = padding_;
|
||||
|
||||
for (int i = 0; i < padding_hi.size(); ++i) {
|
||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
|
||||
}
|
||||
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
|
||||
auto in_trans = group_transpose(in, -1, 0, -1);
|
||||
|
||||
auto grad_trans = conv_general(
|
||||
/* const array& input = */ in_trans,
|
||||
/* const array& weight = */ cotan_trans,
|
||||
/* std::vector<int> stride = */ kernel_dilation_,
|
||||
/* std::vector<int> padding_lo = */ padding_lo,
|
||||
/* std::vector<int> padding_hi = */ padding_hi,
|
||||
/* std::vector<int> kernel_dilation = */ kernel_strides_,
|
||||
/* std::vector<int> input_dilation = */ input_dilation_,
|
||||
/* int groups = */ groups_,
|
||||
/* bool flip = */ false,
|
||||
stream());
|
||||
grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
|
||||
for (int i = 0; i < padding_hi.size(); ++i) {
|
||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
|
||||
}
|
||||
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
|
||||
auto in_trans = group_transpose(in, -1, 0, -1);
|
||||
|
||||
auto grad_trans = conv_general(
|
||||
/* const array& input = */ in_trans,
|
||||
/* const array& weight = */ cotan_trans,
|
||||
/* std::vector<int> stride = */ kernel_dilation_,
|
||||
/* std::vector<int> padding_lo = */ padding_lo,
|
||||
/* std::vector<int> padding_hi = */ padding_hi,
|
||||
/* std::vector<int> kernel_dilation = */ kernel_strides_,
|
||||
/* std::vector<int> input_dilation = */ input_dilation_,
|
||||
/* int groups = */ groups_,
|
||||
/* bool flip = */ false,
|
||||
stream());
|
||||
if (flip_) {
|
||||
auto start = Shape(grad_trans.ndim(), 0);
|
||||
auto stop = Shape(grad_trans.ndim(), 0);
|
||||
auto strides = Shape(grad_trans.ndim(), 1);
|
||||
for (int i = 0; i < stop.size(); ++i) {
|
||||
if (i >= 1 && i < stop.size() - 1) {
|
||||
start[i] = grad_trans.shape(i);
|
||||
stop[i] = -start[i] - 1;
|
||||
strides[i] = -1;
|
||||
} else {
|
||||
stop[i] = grad_trans.shape(i);
|
||||
}
|
||||
}
|
||||
grad_trans = slice(grad_trans, start, stop, strides, stream());
|
||||
}
|
||||
grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1921,7 +1921,6 @@ class Slice : public UnaryPrimitive {
|
||||
Shape start_indices_;
|
||||
Shape end_indices_;
|
||||
Shape strides_;
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class SliceUpdate : public UnaryPrimitive {
|
||||
|
@ -700,6 +700,43 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected, jout))
|
||||
|
||||
def test_slice_grads(self):
|
||||
# Slice
|
||||
def fun(a):
|
||||
return a[5:-6:-1]
|
||||
|
||||
a = mx.ones(shape=(5,))
|
||||
cotan = mx.random.uniform(shape=(5,))
|
||||
_, (grad,) = mx.vjp(fun, (a,), (cotan,))
|
||||
self.assertTrue(mx.allclose(grad, cotan[::-1]))
|
||||
|
||||
tan = mx.random.uniform(shape=(5,))
|
||||
mx.eval(tan)
|
||||
_, (grad,) = mx.jvp(fun, (a,), (tan,))
|
||||
self.assertTrue(mx.allclose(grad, tan[::-1]))
|
||||
|
||||
# Slice update
|
||||
def fun(a, b):
|
||||
a[4:-5:-2] = b
|
||||
return a
|
||||
|
||||
a = mx.ones(shape=(4,))
|
||||
b = mx.zeros(shape=(2,))
|
||||
|
||||
cotan = mx.random.uniform(shape=(4,))
|
||||
_, (grad_a, grad_b) = mx.vjp(fun, (a, b), (cotan,))
|
||||
expected_a = mx.array(cotan)
|
||||
expected_a[1::2] = 0.0
|
||||
self.assertTrue(mx.allclose(grad_a, expected_a))
|
||||
self.assertTrue(mx.allclose(grad_b, cotan[4:-5:-2]))
|
||||
|
||||
tan_a = mx.random.uniform(shape=(4,))
|
||||
tan_b = mx.random.uniform(shape=(2,))
|
||||
_, (grad,) = mx.jvp(fun, (a, b), (tan_a, tan_b))
|
||||
expected = tan_a
|
||||
expected[4:-5:-2] = tan_b
|
||||
self.assertTrue(mx.allclose(grad, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -911,6 +911,44 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([[dw00, dw01], [dw10, dw11]])
|
||||
self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5))
|
||||
|
||||
# Test with input dilation
|
||||
inputs = mx.random.normal((1, 14, 14, 2))
|
||||
kernel = mx.random.normal((2, 7, 7, 2))
|
||||
|
||||
def conv_flip(kernel):
|
||||
return mx.conv_general(
|
||||
inputs,
|
||||
kernel,
|
||||
stride=1,
|
||||
padding=([6, 6], [15, 15]),
|
||||
kernel_dilation=(1, 1),
|
||||
input_dilation=(16, 16),
|
||||
groups=1,
|
||||
flip=True,
|
||||
).sum()
|
||||
|
||||
def reverse_sequence(xs, axis=0):
|
||||
indices = mx.arange(xs.shape[axis] - 1, -1, -1)
|
||||
return mx.take(xs, indices, axis=axis)
|
||||
|
||||
def conv_manual_flip(kernel):
|
||||
for ax in range(1, kernel.ndim - 1):
|
||||
kernel = reverse_sequence(kernel, axis=ax)
|
||||
return mx.conv_general(
|
||||
inputs,
|
||||
kernel,
|
||||
stride=1,
|
||||
padding=([6, 6], [15, 15]),
|
||||
kernel_dilation=(1, 1),
|
||||
input_dilation=(16, 16),
|
||||
groups=1,
|
||||
flip=False,
|
||||
).sum()
|
||||
|
||||
grad = mx.grad(conv_flip)(kernel)
|
||||
expected_grad = mx.grad(conv_manual_flip)(kernel)
|
||||
self.assertTrue(mx.allclose(grad, expected_grad))
|
||||
|
||||
def test_conv_groups_grad(self):
|
||||
def fn(x, w):
|
||||
num_groups = x.shape[-1] // w.shape[-1]
|
||||
|
@ -587,10 +587,10 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
|
||||
for idim, kdim, stride, padding, dilation in (
|
||||
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
|
||||
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
|
||||
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
|
||||
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
|
||||
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
|
||||
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
|
||||
((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
|
||||
((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
|
||||
((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)),
|
||||
((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
|
||||
):
|
||||
run_conv_transpose3D_grad(
|
||||
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
||||
|
@ -2816,6 +2816,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(a.shape, (3, 4, 2))
|
||||
self.assertEqual(b.shape, (3, 4, 2))
|
||||
|
||||
def test_slice_update_reversed(self):
|
||||
a = mx.array([1, 2, 3, 4])
|
||||
b = a[::-1]
|
||||
b[::2] = 0
|
||||
self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user