This commit is contained in:
Awni Hannun 2025-04-21 13:04:39 -07:00 committed by GitHub
parent dc4eada7f0
commit 79b527f45f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 107 additions and 0 deletions

View File

@ -1275,6 +1275,61 @@ std::vector<array> Convolution::vjp(
return grads; return grads;
} }
std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto do_conv = [&](const array& in, const array& w, int groups) {
return conv_general(
in,
w,
kernel_strides_,
padding_,
kernel_dilation_,
input_dilation_,
groups,
flip_,
stream());
};
bool in_vmap = axes[0] >= 0;
bool w_vmap = axes[1] >= 0;
auto in = inputs[0];
auto w = inputs[1];
if (in_vmap && !w_vmap) {
// flatten / unflatten the batch dimension
// of the input / output
if (axes[0] > 0) {
in = moveaxis(in, axes[0], 0, stream());
}
auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);
out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());
return {{out}, {0}};
} else if (!in_vmap && w_vmap) {
// flatten into the output channels of w
// unflatten the channels of the output
if (axes[1] > 0) {
w = moveaxis(w, axes[1], 0, stream());
}
auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);
out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());
return {{out}, {static_cast<int>(out.ndim() - 2)}};
} else if (in_vmap && w_vmap) {
// use a group convolution when both inputs are vmapped
auto b = in.shape(axes[0]);
in = moveaxis(in, axes[0], -2, stream());
in = flatten(in, -2, -1, stream());
if (axes[1] > 0) {
w = moveaxis(w, axes[1], 0, stream());
}
auto c_out = w.shape(1);
w = flatten(w, 0, 1, stream());
auto out = do_conv(in, w, groups_ * b);
out = unflatten(out, -1, {b, c_out}, stream());
return {{out}, {static_cast<int>(out.ndim() - 2)}};
} else {
return {{do_conv(in, w, groups_)}, {-1}};
}
}
bool Convolution::is_equivalent(const Primitive& other) const { bool Convolution::is_equivalent(const Primitive& other) const {
const Convolution& c_other = static_cast<const Convolution&>(other); const Convolution& c_other = static_cast<const Convolution&>(other);
return padding_ == c_other.padding_ && return padding_ == c_other.padding_ &&

View File

@ -711,6 +711,7 @@ class Convolution : public UnaryPrimitive {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(Convolution) DEFINE_PRINT(Convolution)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
auto state() const { auto state() const {

View File

@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8)) self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6)) self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
def test_vmap_conv(self):
# vmap input only
x = mx.random.uniform(shape=(2, 2, 5, 4))
w = mx.random.uniform(shape=(8, 3, 4))
expected = mx.stack([mx.conv1d(xi, w) for xi in x])
out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
self.assertTrue(mx.allclose(expected, out))
x = mx.moveaxis(x, 0, 2)
out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
self.assertTrue(mx.allclose(expected, out))
# vmap weights only
x = mx.random.uniform(shape=(2, 5, 4))
w = mx.random.uniform(shape=(3, 8, 3, 4))
expected = mx.stack([mx.conv1d(x, wi) for wi in w])
out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
w = mx.moveaxis(w, 0, 1)
out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
self.assertTrue(mx.allclose(expected, out))
# vmap weights and input
x = mx.random.uniform(shape=(3, 2, 5, 4))
w = mx.random.uniform(shape=(3, 8, 3, 4))
expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
x = mx.random.uniform(shape=(2, 3, 5, 4))
w = mx.random.uniform(shape=(8, 3, 4, 3))
expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
self.assertTrue(mx.allclose(expected, out))
# Test with groups
x = mx.random.uniform(shape=(3, 2, 5, 8))
w = mx.random.uniform(shape=(3, 2, 3, 4))
def gconv(x, w):
return mx.conv1d(x, w, groups=2)
expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
self.assertTrue(mx.allclose(expected, out))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()