From bdb36c9a63051be901f07984c558c9813c8007e2 Mon Sep 17 00:00:00 2001
From: Alex Barron <abarron22@apple.com>
Date: Sun, 7 Jul 2024 21:34:59 -0700
Subject: [PATCH] add zero vjps for bitwise ops and gather w.r.t. index (#1256)

---
 mlx/primitives.cpp         | 40 +++++++++++++++++++++++++++++++-------
 mlx/primitives.h           |  1 +
 python/tests/test_array.py | 13 +++++++++++++
 python/tests/test_ops.py   | 15 +++++++++++++++
 4 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index b5cd9ac25..94136c6a5 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -523,7 +523,7 @@ std::vector<array> AsType::vjp(
     const std::vector<array>&) {
   if (cotangents[0].dtype() != dtype_) {
     throw std::invalid_argument(
-        "[astype] Type of cotangentsgent does not much primal output type.");
+        "[astype] Type of cotangents does not match primal output type.");
   }
   return {astype(cotangents[0], primals[0].dtype(), stream())};
 }
@@ -629,6 +629,26 @@ std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
       {to_ax}};
 }
 
+std::vector<array> BitwiseBinary::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 2);
+  std::vector<array> vjps = {zeros_like(tangents[0], stream())};
+  if (argnums.size() > 1) {
+    vjps.push_back(vjps.back());
+  }
+  return vjps;
+}
+
+std::vector<array> BitwiseBinary::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  return jvp(primals, cotangents, argnums);
+}
+
 std::vector<array> Broadcast::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
@@ -1598,13 +1618,19 @@ std::vector<array> Gather::vjp(
     const std::vector<array>& cotangents,
     const std::vector<int>& argnums,
     const std::vector<array>&) {
-  if (argnums.size() > 1 || argnums[0] != 0) {
-    throw std::invalid_argument(
-        "[gather] Cannot calculate VJP with respect to indices.");
+  std::vector<array> vjps;
+  for (int argnum : argnums) {
+    if (argnum > 0) {
+      // Grads w.r.t. indices are zero
+      vjps.push_back(
+          zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));
+    } else {
+      auto src = zeros_like(primals[0], stream());
+      std::vector<array> inds(primals.begin() + 1, primals.end());
+      vjps.push_back(scatter_add(src, inds, cotangents[0], axes_, stream()));
+    }
   }
-  auto src = zeros_like(primals[0], stream());
-  std::vector<array> inds(primals.begin() + 1, primals.end());
-  return {scatter_add(src, inds, cotangents[0], axes_, stream())};
+  return vjps;
 }
 
 std::vector<array> Gather::jvp(
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 8f49a4c1d..4bd3b421d 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -471,6 +471,7 @@ class BitwiseBinary : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
   DEFINE_VMAP()
+  DEFINE_GRADS()
   bool is_equivalent(const Primitive& other) const override;
   void print(std::ostream& os) override;
   DEFINE_INPUT_OUTPUT_SHAPE()
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 67e679c35..7e9b68fa5 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -1043,6 +1043,19 @@ class TestArray(mlx_tests.MLXTestCase):
         a_mlx = mx.array(a_np)
         self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))
 
+    def test_indexing_grad(self):
+        x = mx.array([[1, 2], [3, 4]]).astype(mx.float32)
+        ind = mx.array([0, 1, 0]).astype(mx.float32)
+
+        def index_fn(x, ind):
+            return x[ind.astype(mx.int32)].sum()
+
+        grad_x, grad_ind = mx.grad(index_fn, argnums=(0, 1))(x, ind)
+        expected = mx.array([[2, 2], [1, 1]])
+
+        self.assertTrue(mx.array_equal(grad_x, expected))
+        self.assertTrue(mx.array_equal(grad_ind, mx.zeros(ind.shape)))
+
     def test_setitem(self):
         a = mx.array(0)
         a[None] = 1
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index dace9c9aa..9586245e6 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2373,6 +2373,21 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(c.shape, (3, 2, 5))
         self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
 
+    def test_bitwise_grad(self):
+        a = np.random.randint(0, 10, size=(4, 3))
+        b = np.random.randint(0, 10, size=(4, 3))
+        cotangent = np.random.randint(0, 10, size=(4, 3))
+        a = mx.array(a)
+        b = mx.array(b)
+        cotangent = mx.array(cotangent)
+
+        def bitwise(a, b):
+            return a.astype(mx.int32) & b.astype(mx.int32)
+
+        _, vjps = mx.vjp(bitwise, [a, b], [cotangent])
+        for vjp in vjps:
+            self.assertFalse(np.any(np.array(vjp)))
+
     def test_conjugate(self):
         shape = (3, 5, 7)
         a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)