mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
conv vmap (#2102)
This commit is contained in:
parent
dc4eada7f0
commit
79b527f45f
@ -1275,6 +1275,61 @@ std::vector<array> Convolution::vjp(
|
||||
return grads;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto do_conv = [&](const array& in, const array& w, int groups) {
|
||||
return conv_general(
|
||||
in,
|
||||
w,
|
||||
kernel_strides_,
|
||||
padding_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups,
|
||||
flip_,
|
||||
stream());
|
||||
};
|
||||
bool in_vmap = axes[0] >= 0;
|
||||
bool w_vmap = axes[1] >= 0;
|
||||
auto in = inputs[0];
|
||||
auto w = inputs[1];
|
||||
if (in_vmap && !w_vmap) {
|
||||
// flatten / unflatten the batch dimension
|
||||
// of the input / output
|
||||
if (axes[0] > 0) {
|
||||
in = moveaxis(in, axes[0], 0, stream());
|
||||
}
|
||||
auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);
|
||||
out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());
|
||||
return {{out}, {0}};
|
||||
} else if (!in_vmap && w_vmap) {
|
||||
// flatten into the output channels of w
|
||||
// unflatten the channels of the output
|
||||
if (axes[1] > 0) {
|
||||
w = moveaxis(w, axes[1], 0, stream());
|
||||
}
|
||||
auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);
|
||||
out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());
|
||||
return {{out}, {static_cast<int>(out.ndim() - 2)}};
|
||||
} else if (in_vmap && w_vmap) {
|
||||
// use a group convolution when both inputs are vmapped
|
||||
auto b = in.shape(axes[0]);
|
||||
in = moveaxis(in, axes[0], -2, stream());
|
||||
in = flatten(in, -2, -1, stream());
|
||||
if (axes[1] > 0) {
|
||||
w = moveaxis(w, axes[1], 0, stream());
|
||||
}
|
||||
auto c_out = w.shape(1);
|
||||
w = flatten(w, 0, 1, stream());
|
||||
auto out = do_conv(in, w, groups_ * b);
|
||||
out = unflatten(out, -1, {b, c_out}, stream());
|
||||
return {{out}, {static_cast<int>(out.ndim() - 2)}};
|
||||
} else {
|
||||
return {{do_conv(in, w, groups_)}, {-1}};
|
||||
}
|
||||
}
|
||||
|
||||
bool Convolution::is_equivalent(const Primitive& other) const {
|
||||
const Convolution& c_other = static_cast<const Convolution&>(other);
|
||||
return padding_ == c_other.padding_ &&
|
||||
|
@ -711,6 +711,7 @@ class Convolution : public UnaryPrimitive {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Convolution)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
|
@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
|
||||
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
|
||||
|
||||
def test_vmap_conv(self):
|
||||
# vmap input only
|
||||
x = mx.random.uniform(shape=(2, 2, 5, 4))
|
||||
w = mx.random.uniform(shape=(8, 3, 4))
|
||||
|
||||
expected = mx.stack([mx.conv1d(xi, w) for xi in x])
|
||||
out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
x = mx.moveaxis(x, 0, 2)
|
||||
out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# vmap weights only
|
||||
x = mx.random.uniform(shape=(2, 5, 4))
|
||||
w = mx.random.uniform(shape=(3, 8, 3, 4))
|
||||
|
||||
expected = mx.stack([mx.conv1d(x, wi) for wi in w])
|
||||
out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
w = mx.moveaxis(w, 0, 1)
|
||||
out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# vmap weights and input
|
||||
x = mx.random.uniform(shape=(3, 2, 5, 4))
|
||||
w = mx.random.uniform(shape=(3, 8, 3, 4))
|
||||
|
||||
expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
|
||||
out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
x = mx.random.uniform(shape=(2, 3, 5, 4))
|
||||
w = mx.random.uniform(shape=(8, 3, 4, 3))
|
||||
|
||||
expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
|
||||
out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# Test with groups
|
||||
x = mx.random.uniform(shape=(3, 2, 5, 8))
|
||||
w = mx.random.uniform(shape=(3, 2, 3, 4))
|
||||
|
||||
def gconv(x, w):
|
||||
return mx.conv1d(x, w, groups=2)
|
||||
|
||||
expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
|
||||
out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user