diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 590af60f6..3d36f0881 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1275,6 +1275,61 @@ std::vector<array> Convolution::vjp(
   return grads;
 }
 
+std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto do_conv = [&](const array& in, const array& w, int groups) {
+    return conv_general(
+        in,
+        w,
+        kernel_strides_,
+        padding_,
+        kernel_dilation_,
+        input_dilation_,
+        groups,
+        flip_,
+        stream());
+  };
+  bool in_vmap = axes[0] >= 0;
+  bool w_vmap = axes[1] >= 0;
+  auto in = inputs[0];
+  auto w = inputs[1];
+  if (in_vmap && !w_vmap) {
+    // flatten / unflatten the batch dimension
+    // of the input / output
+    if (axes[0] > 0) {
+      in = moveaxis(in, axes[0], 0, stream());
+    }
+    auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);
+    out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());
+    return {{out}, {0}};
+  } else if (!in_vmap && w_vmap) {
+    // flatten into the output channels of w
+    // unflatten the channels of the output
+    if (axes[1] > 0) {
+      w = moveaxis(w, axes[1], 0, stream());
+    }
+    auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);
+    out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());
+    return {{out}, {static_cast<int>(out.ndim() - 2)}};
+  } else if (in_vmap && w_vmap) {
+    // use a group convolution when both inputs are vmapped
+    auto b = in.shape(axes[0]);
+    in = moveaxis(in, axes[0], -2, stream());
+    in = flatten(in, -2, -1, stream());
+    if (axes[1] > 0) {
+      w = moveaxis(w, axes[1], 0, stream());
+    }
+    auto c_out = w.shape(1);
+    w = flatten(w, 0, 1, stream());
+    auto out = do_conv(in, w, groups_ * b);
+    out = unflatten(out, -1, {b, c_out}, stream());
+    return {{out}, {static_cast<int>(out.ndim() - 2)}};
+  } else {
+    return {{do_conv(in, w, groups_)}, {-1}};
+  }
+}
+
 bool Convolution::is_equivalent(const Primitive& other) const {
   const Convolution& c_other = static_cast<const Convolution&>(other);
   return padding_ == c_other.padding_ &&
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 997931f30..3753e43c5 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -711,6 +711,7 @@ class Convolution : public UnaryPrimitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
+  DEFINE_VMAP()
   DEFINE_PRINT(Convolution)
   bool is_equivalent(const Primitive& other) const override;
   auto state() const {
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index 1a1ba23b3..e571678d3 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
         self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
         self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
 
+    def test_vmap_conv(self):
+        # vmap input only
+        x = mx.random.uniform(shape=(2, 2, 5, 4))
+        w = mx.random.uniform(shape=(8, 3, 4))
+
+        expected = mx.stack([mx.conv1d(xi, w) for xi in x])
+        out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        x = mx.moveaxis(x, 0, 2)
+        out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # vmap weights only
+        x = mx.random.uniform(shape=(2, 5, 4))
+        w = mx.random.uniform(shape=(3, 8, 3, 4))
+
+        expected = mx.stack([mx.conv1d(x, wi) for wi in w])
+        out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        w = mx.moveaxis(w, 0, 1)
+        out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # vmap weights and input
+        x = mx.random.uniform(shape=(3, 2, 5, 4))
+        w = mx.random.uniform(shape=(3, 8, 3, 4))
+
+        expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
+        out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        x = mx.random.uniform(shape=(2, 3, 5, 4))
+        w = mx.random.uniform(shape=(8, 3, 4, 3))
+
+        expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
+        out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # Test with groups
+        x = mx.random.uniform(shape=(3, 2, 5, 8))
+        w = mx.random.uniform(shape=(3, 2, 3, 4))
+
+        def gconv(x, w):
+            return mx.conv1d(x, w, groups=2)
+
+        expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
+        out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
 
 if __name__ == "__main__":
     unittest.main()
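Note (not part of the patch): below is a minimal standalone sketch of the grouped-convolution equivalence the new vmap rule relies on when both the input and the weights are vmapped. Variable names and shapes are illustrative only; it assumes the mx.flatten / mx.unflatten / mx.moveaxis Python bindings, which mirror the C++ ops used in the diff above.

import mlx.core as mx

b = 3  # size of the vmapped axis
x = mx.random.uniform(shape=(b, 2, 5, 4))  # b stacked (N=2, L=5, C_in=4) inputs
w = mx.random.uniform(shape=(b, 8, 3, 4))  # b stacked (C_out=8, K=3, C_in=4) weights

# Reference: one conv1d per vmapped slice.
expected = mx.stack([mx.conv1d(x[i], w[i]) for i in range(b)])

# Grouped-conv equivalent, mirroring the in_vmap && w_vmap branch:
# fold the vmapped axis into the channel dimensions and run a single
# convolution with the group count multiplied by b.
x_g = mx.flatten(mx.moveaxis(x, 0, -2), -2, -1)  # (2, 5, b * 4)
w_g = mx.flatten(w, 0, 1)                        # (b * 8, 3, 4)
out = mx.conv1d(x_g, w_g, groups=b)              # (2, 3, b * 8)
out = mx.moveaxis(mx.unflatten(out, -1, (b, 8)), -2, 0)

print(mx.allclose(expected, out))  # True

Because each group of b * 4 input channels only ever convolves with its own block of 8 filters, filter block i sees exactly and only input slice i, which is the per-slice independence vmap requires.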