Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
conv vmap (#2102)
commit 79b527f45f
parent dc4eada7f0
@@ -1275,6 +1275,61 @@ std::vector<array> Convolution::vjp(
   return grads;
 }
 
+std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto do_conv = [&](const array& in, const array& w, int groups) {
+    return conv_general(
+        in,
+        w,
+        kernel_strides_,
+        padding_,
+        kernel_dilation_,
+        input_dilation_,
+        groups,
+        flip_,
+        stream());
+  };
+  bool in_vmap = axes[0] >= 0;
+  bool w_vmap = axes[1] >= 0;
+  auto in = inputs[0];
+  auto w = inputs[1];
+  if (in_vmap && !w_vmap) {
+    // flatten / unflatten the batch dimension
+    // of the input / output
+    if (axes[0] > 0) {
+      in = moveaxis(in, axes[0], 0, stream());
+    }
+    auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);
+    out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());
+    return {{out}, {0}};
+  } else if (!in_vmap && w_vmap) {
+    // flatten into the output channels of w
+    // unflatten the channels of the output
+    if (axes[1] > 0) {
+      w = moveaxis(w, axes[1], 0, stream());
+    }
+    auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);
+    out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());
+    return {{out}, {static_cast<int>(out.ndim() - 2)}};
+  } else if (in_vmap && w_vmap) {
+    // use a group convolution when both inputs are vmapped
+    auto b = in.shape(axes[0]);
+    in = moveaxis(in, axes[0], -2, stream());
+    in = flatten(in, -2, -1, stream());
+    if (axes[1] > 0) {
+      w = moveaxis(w, axes[1], 0, stream());
+    }
+    auto c_out = w.shape(1);
+    w = flatten(w, 0, 1, stream());
+    auto out = do_conv(in, w, groups_ * b);
+    out = unflatten(out, -1, {b, c_out}, stream());
+    return {{out}, {static_cast<int>(out.ndim() - 2)}};
+  } else {
+    return {{do_conv(in, w, groups_)}, {-1}};
+  }
+}
+
 bool Convolution::is_equivalent(const Primitive& other) const {
   const Convolution& c_other = static_cast<const Convolution&>(other);
   return padding_ == c_other.padding_ &&
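The branch where both the input and the weights are vmapped is the subtle one: instead of looping one convolution per vmapped element, it stacks the per-element inputs along channels and the per-element weight sets along output channels, then runs a single convolution with groups scaled by the vmap size. Below is a minimal Python sketch of that equivalence; the shapes and variable names are illustrative, not taken from the diff:

import mlx.core as mx

b, c_in, c_out = 3, 4, 8
xs = mx.random.uniform(shape=(b, 2, 5, c_in))      # one input per vmapped element
ws = mx.random.uniform(shape=(b, c_out, 3, c_in))  # one weight set per element

# Reference: convolve each (input, weight) pair separately and stack.
expected = mx.stack([mx.conv1d(xs[i], ws[i]) for i in range(b)])

# The grouped-conv trick from Convolution::vmap: stack the b inputs along
# channels, the b weight sets along output channels, and convolve once
# with groups = b.
x_flat = mx.flatten(mx.moveaxis(xs, 0, -2), -2, -1)  # (2, 5, b * c_in)
w_flat = mx.flatten(ws, 0, 1)                        # (b * c_out, 3, c_in)
out = mx.conv1d(x_flat, w_flat, groups=b)            # (2, 3, b * c_out)
out = mx.unflatten(out, -1, (b, c_out))              # (2, 3, b, c_out)
out = mx.moveaxis(out, -2, 0)                        # (b, 2, 3, c_out)

assert mx.allclose(expected, out)

Group i of the grouped convolution only sees input channels i*c_in:(i+1)*c_in and weight rows i*c_out:(i+1)*c_out, i.e. exactly the i-th (input, weight) pair, so one kernel launch computes all b independent convolutions.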
@@ -711,6 +711,7 @@ class Convolution : public UnaryPrimitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
+  DEFINE_VMAP()
   DEFINE_PRINT(Convolution)
   bool is_equivalent(const Primitive& other) const override;
   auto state() const {
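DEFINE_VMAP() declares the vmap override on Convolution, which is how mx.vmap dispatches batched convolutions to the rule added above. A sketch of the simplest case this enables (shapes illustrative):

import mlx.core as mx

# Batch a 1-D convolution over the leading axis of the input only.
x = mx.random.uniform(shape=(2, 2, 5, 4))  # vmapped axis first
w = mx.random.uniform(shape=(8, 3, 4))     # (C_out, K, C_in)

batched = mx.vmap(mx.conv1d, in_axes=(0, None))
print(batched(x, w).shape)  # (2, 2, 3, 8)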
@@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase):
         self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
         self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
 
+    def test_vmap_conv(self):
+        # vmap input only
+        x = mx.random.uniform(shape=(2, 2, 5, 4))
+        w = mx.random.uniform(shape=(8, 3, 4))
+
+        expected = mx.stack([mx.conv1d(xi, w) for xi in x])
+        out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        x = mx.moveaxis(x, 0, 2)
+        out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # vmap weights only
+        x = mx.random.uniform(shape=(2, 5, 4))
+        w = mx.random.uniform(shape=(3, 8, 3, 4))
+
+        expected = mx.stack([mx.conv1d(x, wi) for wi in w])
+        out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        w = mx.moveaxis(w, 0, 1)
+        out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # vmap weights and input
+        x = mx.random.uniform(shape=(3, 2, 5, 4))
+        w = mx.random.uniform(shape=(3, 8, 3, 4))
+
+        expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])
+        out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        x = mx.random.uniform(shape=(2, 3, 5, 4))
+        w = mx.random.uniform(shape=(8, 3, 4, 3))
+
+        expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])
+        out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # Test with groups
+        x = mx.random.uniform(shape=(3, 2, 5, 8))
+        w = mx.random.uniform(shape=(3, 2, 3, 4))
+
+        def gconv(x, w):
+            return mx.conv1d(x, w, groups=2)
+
+        expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])
+        out = mx.vmap(gconv, in_axes=(0, 0))(x, w)
+        self.assertTrue(mx.allclose(expected, out))
+
+
 if __name__ == "__main__":
     unittest.main()
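For reference, the weights-only case the second test block exercises can be written out by hand with the same flatten trick the primitive uses internally. A sketch under the same illustrative shapes as the tests:

import mlx.core as mx

b, c_out, k, c_in = 3, 8, 3, 4
x = mx.random.uniform(shape=(2, 5, c_in))
ws = mx.random.uniform(shape=(b, c_out, k, c_in))

# Flatten the b weight sets into output channels, convolve once, then
# unflatten the channel axis back into (b, c_out) and move b to the front.
out = mx.conv1d(x, mx.flatten(ws, 0, 1))  # (2, 3, b * c_out)
out = mx.unflatten(out, -1, (b, c_out))   # (2, 3, b, c_out)
out = mx.moveaxis(out, -2, 0)             # (b, 2, 3, c_out)

expected = mx.stack([mx.conv1d(x, wi) for wi in ws])
assert mx.allclose(expected, out)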