diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 65ed3006f..e0d70ea14 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -121,6 +121,7 @@ Operations
    pad
    power
    prod
+   put_along_axis
    quantize
    quantized_matmul
    radians
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 200afd9dc..f69943cd8 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2767,6 +2767,54 @@ array take_along_axis(
   return reshape(out, out_shape, s);
 }
 
+array put_along_axis(
+    const array& a,
+    const array& indices,
+    const array& values,
+    int axis,
+    StreamOrDevice s /* = {} */) {
+  if (axis + static_cast<int>(a.ndim()) < 0 ||
+      axis >= static_cast<int>(a.ndim())) {
+    std::ostringstream msg;
+    msg << "[put_along_axis] Received invalid axis " << axis
+        << " for array with " << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (indices.ndim() != a.ndim()) {
+    std::ostringstream msg;
+    msg << "[put_along_axis] Indices of dimension " << indices.ndim()
+        << " does not match array of dimension " << a.ndim() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Allow negative axis
+  axis = axis < 0 ? a.ndim() + axis : axis;
+
+  std::vector<array> nd_indices;
+  std::vector<int> index_shape(a.ndim(), 1);
+  for (int i = 0; i < a.ndim(); ++i) {
+    if (i == axis) {
+      nd_indices.push_back(indices);
+    } else {
+      // Reshape so they can be broadcast
+      index_shape[i] = a.shape(i);
+      nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s));
+      index_shape[i] = 1;
+    }
+  }
+
+  auto update = astype(broadcast_to(values, indices.shape(), s), a.dtype(), s);
+  {
+    auto update_shape = update.shape();
+    update_shape.resize(update_shape.size() + a.ndim(), 1);
+    update = reshape(update, std::move(update_shape), s);
+  }
+  std::vector<int> dims(a.ndim());
+  std::iota(dims.begin(), dims.end(), 0);
+  return scatter(a, nd_indices, update, dims, s);
+}
+
 /** Scatter updates to given indices */
 array scatter(
     const array& a,
@@ -2853,7 +2900,6 @@ array scatter(
   }
 
   inputs.insert(inputs.begin(), a);
-  // TODO promote or cast?
   inputs.push_back(astype(updates, a.dtype(), s));
 
   return array(
diff --git a/mlx/ops.h b/mlx/ops.h
index 42445ccb6..b0c093f47 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -947,6 +947,14 @@ array take_along_axis(
     int axis,
     StreamOrDevice s = {});
 
+/** Put the values into the array at the given indices along the axis */
+array put_along_axis(
+    const array& a,
+    const array& indices,
+    const array& values,
+    int axis,
+    StreamOrDevice s = {});
+
 /** Scatter updates to the given indices.
  *
  * The parameters ``indices`` and ``axes`` determine the locations of ``a``
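For intuition, the scatter construction above behaves like the following NumPy
sketch (illustrative only; put_along_axis_reference is a hypothetical helper,
and it assumes indices already spans the non-axis dimensions, i.e. no
broadcasting of the index grids):

    import numpy as np

    def put_along_axis_reference(a, indices, values, axis):
        # Work on a copy; the C++ op is functional and returns a new array.
        out = a.copy()
        # Broadcast values against indices and match the destination dtype,
        # mirroring the broadcast_to/astype calls in put_along_axis.
        values = np.broadcast_to(values, indices.shape).astype(a.dtype)
        # One index grid per dimension: arange-style grids everywhere except
        # `axis`, where the user-provided indices are used. This mirrors the
        # nd_indices vector handed to scatter.
        grids = list(np.indices(indices.shape))
        grids[axis] = indices
        out[tuple(grids)] = values
        return out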
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index a1549fa6f..c28a945a3 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -471,6 +471,21 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
   return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
+std::vector<array> ArgPartition::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>&,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return {zeros_like(primals[0], stream())};
+}
+
+std::vector<array> ArgPartition::jvp(
+    const std::vector<array>&,
+    const std::vector<array>& tangents,
+    const std::vector<int>&) {
+  return {zeros_like(tangents[0], stream())};
+}
+
 bool ArgPartition::is_equivalent(const Primitive& other) const {
   const ArgPartition& r_other = static_cast<const ArgPartition&>(other);
   return axis_ == r_other.axis_ && kth_ == r_other.kth_;
@@ -495,6 +510,21 @@ std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
   return {out, axes};
 }
 
+std::vector<array> ArgReduce::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>&,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return {zeros_like(primals[0], stream())};
+}
+
+std::vector<array> ArgReduce::jvp(
+    const std::vector<array>&,
+    const std::vector<array>& tangents,
+    const std::vector<int>&) {
+  return {zeros_like(tangents[0], stream())};
+}
+
 std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -2336,7 +2366,14 @@ std::vector<array> Partition::vjp(
     const std::vector<array>& cotangents,
     const std::vector<int>& argnums,
     const std::vector<array>&) {
-  return jvp(primals, cotangents, argnums);
+  // Route the cotangents back to their pre-partition locations.
+  auto sort_idx = argpartition(primals[0], kth_, axis_, stream());
+  return {put_along_axis(
+      zeros_like(primals[0], stream()),
+      sort_idx,
+      cotangents[0],
+      axis_,
+      stream())};
 }
 
 std::vector<array> Partition::jvp(
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 5e5bda7c0..810eb5096 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -357,6 +357,7 @@ class ArgPartition : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
   DEFINE_VMAP()
+  DEFINE_GRADS()
   DEFINE_PRINT(ArgPartition)
   DEFINE_INPUT_OUTPUT_SHAPE()
   bool is_equivalent(const Primitive& other) const override;
@@ -382,6 +383,7 @@ class ArgReduce : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
   DEFINE_VMAP()
+  DEFINE_GRADS()
   DEFINE_PRINT(ArgReduce)
   bool is_equivalent(const Primitive& other) const override;
   std::vector<std::vector<int>> output_shapes(
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index ba814f1a2..e1117786b 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -1463,7 +1463,48 @@ void init_ops(nb::module_& m) {
               operation.
 
         Returns:
-            array: The output array with the specified shape and values.
+            array: The output array.
       )pbdoc");
+  m.def(
+      "put_along_axis",
+      [](const array& a,
+         const array& indices,
+         const array& values,
+         const std::optional<int>& axis,
+         StreamOrDevice s) {
+        if (axis.has_value()) {
+          return put_along_axis(a, indices, values, axis.value(), s);
+        } else {
+          // With axis=None, flatten, put, and restore the original shape.
+          return reshape(
+              put_along_axis(reshape(a, {-1}, s), indices, values, 0, s),
+              a.shape(),
+              s);
+        }
+      },
+      nb::arg(),
+      "indices"_a,
+      "values"_a,
+      "axis"_a.none(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def put_along_axis(a: array, /, indices: array, values: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Put values along an axis at the specified indices.
+
+        Args:
+            a (array): Destination array.
+            indices (array): Indices array. These should be broadcastable with
+              the input array excluding the ``axis`` dimension.
+            values (array): Values array. These should be broadcastable with
+              the indices.
+            axis (int or None): Axis in the destination along which to put the
+              values. If ``axis == None`` the destination is flattened prior
+              to the put operation.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
   m.def(
       "full",
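A quick usage sketch of the new binding (hypothetical session; values chosen
for illustration):

    import mlx.core as mx

    a = mx.array([[1, 2, 3], [4, 5, 6]])
    idx = mx.array([[0], [2]])  # one target column per row
    out = mx.put_along_axis(a, idx, mx.array(-1), axis=1)
    # out: [[-1, 2, 3], [4, 5, -1]]

    # With axis=None the destination is flattened first, as in NumPy.
    flat = mx.put_along_axis(a, mx.array([0, 5]), mx.array([7, 8]), axis=None)
    # flat: [[7, 2, 3], [4, 5, 8]]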
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 7db0d49ff..1e99c3825 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -496,6 +496,16 @@ class TestAutograd(mlx_tests.MLXTestCase):
         expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
         self.assertTrue(mx.allclose(out, expected))
 
+    def test_topk_grad(self):
+        a = mx.array([[1, 2, 6, 4, 5], [9, 5, 6, 7, 8]], mx.float32)
+
+        def fun(x):
+            return mx.topk(x, 2)
+
+        out = mx.vjp(fun, (a,), (mx.ones((2, 2)),))[1][0]
+        expected = mx.array([[0, 0, 1, 0, 1], [1, 0, 0, 0, 1]], mx.float32)
+        self.assertTrue(mx.array_equal(out, expected))
+
     def test_custom_function(self):
         # Make a custom function
         my_exp = mx.custom_function(mx.exp)
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 7dcde16d7..7a9404f27 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1075,6 +1075,31 @@ class TestOps(mlx_tests.MLXTestCase):
                 out_mlx = mx.take_along_axis(a_mlx, mx.reshape(idx_mlx, shape), axis=ax)
                 self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
 
+    def test_put_along_axis(self):
+        for ax in [None, 0, 1, 2]:
+            a_np = np.arange(16).reshape(2, 2, 4).astype(np.int32)
+            a_mlx = mx.array(a_np)
+
+            if ax is None:
+                idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
+                values_np = np.random.randint(low=0, high=100, size=(16,))
+            else:
+                shape = list(a_np.shape)
+                shape[ax] = 2
+                idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
+                values_np = np.random.randint(low=0, high=100, size=shape)
+
+            idx_np = idx_np.astype(np.int32)
+            values_np = values_np.astype(a_np.dtype)
+
+            idx_mlx = mx.array(idx_np)
+            values_mlx = mx.array(values_np)
+
+            # np.put_along_axis mutates a_np in place; compare against it.
+            np.put_along_axis(a_np, idx_np, values_np, axis=ax)
+            out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)
+            self.assertTrue(np.array_equal(a_np, out_mlx))
+
     def test_split(self):
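With the new Partition::vjp, gradients flow through mx.topk and mx.partition
to the selected elements, as exercised by test_topk_grad above. A hypothetical
session:

    import mlx.core as mx

    a = mx.array([[1.0, 2.0, 6.0, 4.0, 5.0]])
    grad = mx.grad(lambda x: mx.topk(x, 2).sum())(a)
    # grad: [[0, 0, 1, 0, 1]] -- the cotangent of each selected element is
    # scattered back to its source position via put_along_axis.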
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index fb6507267..6eec3e99c 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -1983,6 +1983,12 @@ TEST_CASE("test take") {
   CHECK(array_equal(out, zeros({1, 1, 1})).item<bool>());
   out = take(a, array({0, 1}), 1);
   CHECK(array_equal(out, zeros({1, 2, 1})).item<bool>());
+
+  // Indices have wrong shape
+  a = zeros({2, 3, 4});
+  CHECK_THROWS(take(a, zeros({1, 3, 4}), 1));
+  CHECK_THROWS(take(a, zeros({2, 3, 7}), 1));
+  CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));
 }
 
 TEST_CASE("test take along axis") {
@@ -2001,12 +2007,6 @@ TEST_CASE("test take along axis") {
   out = take_along_axis(a, array({1}), -1);
   CHECK_EQ(out.item<int>(), 1);
 
-  // Indices have wrong shape
-  a = zeros({2, 3, 4});
-  CHECK_THROWS(take(a, zeros({1, 3, 4}), 1));
-  CHECK_THROWS(take(a, zeros({2, 3, 7}), 1));
-  CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));
-
   // Empty arrays
   a = reshape(array({}), {1, 0});
   CHECK_THROWS(take_along_axis(a, array({1}), 0));
@@ -2057,6 +2057,48 @@ TEST_CASE("test take along axis") {
             .item<bool>());
 }
 
+TEST_CASE("test put along axis") {
+  // No zero dim arrays
+  auto a = array(1);
+  auto v = array(1);
+  CHECK_THROWS(put_along_axis(a, array(0), v, 0));
+
+  // Index and array size mismatches
+  a = arange(5);
+  CHECK_THROWS(put_along_axis(a, array({1}), array({0}), 1));
+  CHECK_THROWS(put_along_axis(a, array({1}, {1, 1}), array({0}), 0));
+  CHECK_THROWS(put_along_axis(a, array(1), array(0), -1));
+
+  auto expected = array({0, 0, 2, 3, 4});
+  auto out = put_along_axis(a, array({1}), array({0}), 0);
+  CHECK(array_equal(out, expected).item<bool>());
+
+  // Empty arrays
+  a = reshape(array({}), {1, 0});
+  CHECK_THROWS(put_along_axis(a, array({1}), array({0}), 0));
+
+  auto inds = reshape(astype(array({}), int32), {1, 0});
+  out = take_along_axis(a, inds, 0);
+  eval(out); // Make sure it runs
+  CHECK_EQ(out.shape(), std::vector<int>{1, 0});
+
+  a = array({1, 2, 3, 4}, {2, 2});
+  inds = array({0, 1}, {1, 2});
+  out = put_along_axis(a, inds, array({0}), 0);
+  expected = array({0, 2, 3, 0}, {2, 2});
+  CHECK(array_equal(out, expected).item<bool>());
+
+  inds = array({0, 0, 1, 1}, {2, 2}, int32);
+  auto values = array({2, 3, 4, 5}, {2, 2}, int32);
+  out = put_along_axis(a, inds, values, 0);
+  CHECK(array_equal(out, array({2, 3, 4, 5}, {2, 2})).item<bool>());
+
+  inds = array({0, 1}, {2, 1});
+  out = put_along_axis(a, inds, array({0}), 1);
+  expected = array({0, 2, 3, 0}, {2, 2});
+  CHECK(array_equal(out, expected).item<bool>());
+}
+
 TEST_CASE("test scatter") {
   // More indices than dimensions
   CHECK_THROWS(scatter(array(0), array({1}), array(1), 0));