diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index 9a51aefaf..c446ff948 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -35,4 +35,29 @@ void shared_buffer_slice( move_or_copy(in, out, out_strides, flags, data_size, data_offset); } +void slice( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + // Calculate out strides, initial offset + auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides); + int64_t data_end = 1; + for (int i = 0; i < start_indices.size(); ++i) { + if (in.shape()[i] > 1) { + auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1; + data_end += end_idx * in.strides()[i]; + } + } + // data_end can be -1 + size_t data_size = + data_end < 0 ? (data_offset - data_end) : (data_end - data_offset); + shared_buffer_slice(in, inp_strides, data_offset, data_size, out); +} + } // namespace mlx::core diff --git a/mlx/backend/common/slicing.h b/mlx/backend/common/slicing.h index eda37320d..b667d2619 100644 --- a/mlx/backend/common/slicing.h +++ b/mlx/backend/common/slicing.h @@ -11,11 +11,10 @@ std::tuple prepare_slice( const Shape& start_indices, const Shape& strides); -void shared_buffer_slice( +void slice( const array& in, - const Strides& out_strides, - size_t data_offset, - size_t data_size, - array& out); + array& out, + const Shape& start_indices, + const Shape& strides); } // namespace mlx::core diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index 8ae2cc520..b5d9d7ef3 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -86,7 +86,7 @@ void NumberOfElements::eval_cpu(const std::vector& inputs, array& out) { eval(inputs, out); } void Slice::eval_cpu(const std::vector& inputs, array& out) { - eval(inputs, out); + slice(inputs[0], out, start_indices_, strides_); } void Split::eval_cpu( const std::vector& inputs, @@ -262,29 +262,6 @@ void Reshape::eval_cpu(const std::vector& inputs, array& out) { reshape(inputs[0], out); } -void Slice::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - - // Calculate out strides, initial offset and if copy needs to be made - auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_); - size_t data_end = 1; - for (int i = 0; i < end_indices_.size(); ++i) { - if (in.shape()[i] > 1) { - auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1; - data_end += end_idx * in.strides()[i]; - } - } - size_t data_size = data_end - data_offset; - Strides ostrides{inp_strides.begin(), inp_strides.end()}; - shared_buffer_slice(in, ostrides, data_offset, data_size, out); -} - void DynamicSlice::eval_cpu(const std::vector& inputs, array& out) { if (out.size() == 0) { out.set_data(nullptr); @@ -355,7 +332,8 @@ void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); // Calculate out strides, initial offset and if copy needs to be made - auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_); + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); // Do copy copy_inplace( diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 29220c53b..c64e05931 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -11,7 +11,7 @@ instantiate_kernel( \ "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ instantiate_kernel( \ - "gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) + "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) #define instantiate_unary_all_same(op, tname, type) \ instantiate_unary_all(op, tname, tname, type, type) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 47190daf3..627f30478 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -499,8 +499,8 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { ? CopyType::Vector : CopyType::General; copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); - - auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_); + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); // Do copy copy_gpu_inplace( diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 3493ea858..d34e1a747 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -14,18 +14,7 @@ void slice_gpu( const Shape& start_indices, const Shape& strides, const Stream& s) { - // Calculate out strides and initial offset - auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides); - - size_t data_end = 1; - for (int i = 0; i < strides.size(); ++i) { - if (in.shape()[i] > 1) { - auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1; - data_end += end_idx * in.strides()[i]; - } - } - size_t data_size = data_end - data_offset; - shared_buffer_slice(in, inp_strides, data_offset, data_size, out); + slice(in, out, start_indices, strides); } void concatenate_gpu( diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 1496bcb5d..77f82b8fc 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -1,5 +1,4 @@ // Copyright © 2024 Apple Inc. - #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" @@ -49,7 +48,7 @@ void unary_op_gpu_inplace( } else { kernel_name = "gn" + std::to_string(work_per_thread); if (large) { - kernel_name += "_large"; + kernel_name += "large"; } } concatenate(kernel_name, "_", op, type_to_name(in), type_to_name(out)); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 8853d8351..03785c11d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -599,7 +599,13 @@ array expand_dims( namespace { inline auto -normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) { +normalize_slice(const Shape& shape, Shape& start, Shape stop, Shape& strides) { + // - Start indices are normalized + // - End indices are unchanged as -1 means something different + // pre-normalization (the end of the axis) versus post normalization (the + // position left of 0). + // - Any strides corresponding to singleton dimension are set to 1 + Shape out_shape(shape.size()); bool has_neg_strides = false; @@ -624,10 +630,10 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) { auto ed = e > -1 ? e : -1; start[i] = st; - stop[i] = ed > st ? st : ed; + ed = ed > st ? st : ed; auto str = -strides[i]; - out_shape[i] = (start[i] - stop[i] + str - 1) / str; + out_shape[i] = (start[i] - ed + str - 1) / str; } else { // Clamp to bounds @@ -635,9 +641,9 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) { auto ed = std::max(static_cast(0), std::min(e, n)); start[i] = st; - stop[i] = ed < st ? st : ed; + ed = ed < st ? st : ed; - out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i]; + out_shape[i] = (ed - start[i] + strides[i] - 1) / strides[i]; } // Simplify the stride if it's unused if (out_shape[i] == 1) { diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 4cde84831..90ae57906 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1229,58 +1229,45 @@ std::vector Convolution::vjp( in, wt, cotan, kernel_strides_, padding_, stream()); grads.push_back(grad); } else { - if (flip_) { - auto padding = padding_; - for (int i = 0; i < padding.size(); i++) { - int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding[i] = wt_size - padding_[i] - 1; - } + std::vector padding_lo = padding_; + std::vector padding_hi = padding_; - auto cotan_trans = group_transpose(cotan, -1, 0, -1); - auto in_trans = swapaxes(in, 0, -1, stream()); - - auto grad_trans = conv_general( - /* const array& input = */ cotan_trans, - /* const array& weight = */ in_trans, - /* std::vector stride = */ kernel_dilation_, - /* std::vector padding_lo = */ padding, - /* std::vector padding_hi = */ padding, - /* std::vector kernel_dilation = */ input_dilation_, - /* std::vector input_dilation = */ kernel_strides_, - /* int groups = */ groups_, - /* bool flip = */ false, - stream()); - if (groups_ > 1) { - grads.push_back(group_transpose(grad_trans, -1, 0, -2)); - } else { - grads.push_back(grad_trans); - } - } else { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; - - for (int i = 0; i < padding_hi.size(); ++i) { - int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); - int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; - } - auto cotan_trans = swapaxes(cotan, 0, -1, stream()); - auto in_trans = group_transpose(in, -1, 0, -1); - - auto grad_trans = conv_general( - /* const array& input = */ in_trans, - /* const array& weight = */ cotan_trans, - /* std::vector stride = */ kernel_dilation_, - /* std::vector padding_lo = */ padding_lo, - /* std::vector padding_hi = */ padding_hi, - /* std::vector kernel_dilation = */ kernel_strides_, - /* std::vector input_dilation = */ input_dilation_, - /* int groups = */ groups_, - /* bool flip = */ false, - stream()); - grads.push_back(swapaxes(grad_trans, 0, -1, stream())); + for (int i = 0; i < padding_hi.size(); ++i) { + int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); + int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); + int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); + padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; } + auto cotan_trans = swapaxes(cotan, 0, -1, stream()); + auto in_trans = group_transpose(in, -1, 0, -1); + + auto grad_trans = conv_general( + /* const array& input = */ in_trans, + /* const array& weight = */ cotan_trans, + /* std::vector stride = */ kernel_dilation_, + /* std::vector padding_lo = */ padding_lo, + /* std::vector padding_hi = */ padding_hi, + /* std::vector kernel_dilation = */ kernel_strides_, + /* std::vector input_dilation = */ input_dilation_, + /* int groups = */ groups_, + /* bool flip = */ false, + stream()); + if (flip_) { + auto start = Shape(grad_trans.ndim(), 0); + auto stop = Shape(grad_trans.ndim(), 0); + auto strides = Shape(grad_trans.ndim(), 1); + for (int i = 0; i < stop.size(); ++i) { + if (i >= 1 && i < stop.size() - 1) { + start[i] = grad_trans.shape(i); + stop[i] = -start[i] - 1; + strides[i] = -1; + } else { + stop[i] = grad_trans.shape(i); + } + } + grad_trans = slice(grad_trans, start, stop, strides, stream()); + } + grads.push_back(swapaxes(grad_trans, 0, -1, stream())); } } } diff --git a/mlx/primitives.h b/mlx/primitives.h index b1f55f9ac..fed5b4988 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1921,7 +1921,6 @@ class Slice : public UnaryPrimitive { Shape start_indices_; Shape end_indices_; Shape strides_; - void eval(const std::vector& inputs, array& out); }; class SliceUpdate : public UnaryPrimitive { diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index b3281bd3e..3c226365f 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -700,6 +700,43 @@ class TestAutograd(mlx_tests.MLXTestCase): expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None] self.assertTrue(mx.allclose(expected, jout)) + def test_slice_grads(self): + # Slice + def fun(a): + return a[5:-6:-1] + + a = mx.ones(shape=(5,)) + cotan = mx.random.uniform(shape=(5,)) + _, (grad,) = mx.vjp(fun, (a,), (cotan,)) + self.assertTrue(mx.allclose(grad, cotan[::-1])) + + tan = mx.random.uniform(shape=(5,)) + mx.eval(tan) + _, (grad,) = mx.jvp(fun, (a,), (tan,)) + self.assertTrue(mx.allclose(grad, tan[::-1])) + + # Slice update + def fun(a, b): + a[4:-5:-2] = b + return a + + a = mx.ones(shape=(4,)) + b = mx.zeros(shape=(2,)) + + cotan = mx.random.uniform(shape=(4,)) + _, (grad_a, grad_b) = mx.vjp(fun, (a, b), (cotan,)) + expected_a = mx.array(cotan) + expected_a[1::2] = 0.0 + self.assertTrue(mx.allclose(grad_a, expected_a)) + self.assertTrue(mx.allclose(grad_b, cotan[4:-5:-2])) + + tan_a = mx.random.uniform(shape=(4,)) + tan_b = mx.random.uniform(shape=(2,)) + _, (grad,) = mx.jvp(fun, (a, b), (tan_a, tan_b)) + expected = tan_a + expected[4:-5:-2] = tan_b + self.assertTrue(mx.allclose(grad, expected)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 79324829b..862e8ec7f 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -911,6 +911,44 @@ class TestConv(mlx_tests.MLXTestCase): expected = mx.array([[dw00, dw01], [dw10, dw11]]) self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5)) + # Test with input dilation + inputs = mx.random.normal((1, 14, 14, 2)) + kernel = mx.random.normal((2, 7, 7, 2)) + + def conv_flip(kernel): + return mx.conv_general( + inputs, + kernel, + stride=1, + padding=([6, 6], [15, 15]), + kernel_dilation=(1, 1), + input_dilation=(16, 16), + groups=1, + flip=True, + ).sum() + + def reverse_sequence(xs, axis=0): + indices = mx.arange(xs.shape[axis] - 1, -1, -1) + return mx.take(xs, indices, axis=axis) + + def conv_manual_flip(kernel): + for ax in range(1, kernel.ndim - 1): + kernel = reverse_sequence(kernel, axis=ax) + return mx.conv_general( + inputs, + kernel, + stride=1, + padding=([6, 6], [15, 15]), + kernel_dilation=(1, 1), + input_dilation=(16, 16), + groups=1, + flip=False, + ).sum() + + grad = mx.grad(conv_flip)(kernel) + expected_grad = mx.grad(conv_manual_flip)(kernel) + self.assertTrue(mx.allclose(grad, expected_grad)) + def test_conv_groups_grad(self): def fn(x, w): num_groups = x.shape[-1] // w.shape[-1] diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py index 7b458914d..1ac20cbb1 100644 --- a/python/tests/test_conv_transpose.py +++ b/python/tests/test_conv_transpose.py @@ -587,10 +587,10 @@ class TestConvTranspose(mlx_tests.MLXTestCase): for idim, kdim, stride, padding, dilation in ( ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)), ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)), - ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)), - ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)), - ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)), - ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)), + ((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)), + ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)), + ((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)), + ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)), ): run_conv_transpose3D_grad( N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 31a65f524..ea798c18c 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2816,6 +2816,12 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(a.shape, (3, 4, 2)) self.assertEqual(b.shape, (3, 4, 2)) + def test_slice_update_reversed(self): + a = mx.array([1, 2, 3, 4]) + b = a[::-1] + b[::2] = 0 + self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1]))) + if __name__ == "__main__": unittest.main()