mirror of https://github.com/ml-explore/mlx.git
conv vmap (#2102)
@@ -1275,6 +1275,61 @@ std::vector<array> Convolution::vjp(
  return grads;
}

std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto do_conv = [&](const array& in, const array& w, int groups) {
    return conv_general(
        in,
        w,
        kernel_strides_,
        padding_,
        kernel_dilation_,
        input_dilation_,
        groups,
        flip_,
        stream());
  };
  bool in_vmap = axes[0] >= 0;
  bool w_vmap = axes[1] >= 0;
  auto in = inputs[0];
  auto w = inputs[1];
  if (in_vmap && !w_vmap) {
    // flatten / unflatten the batch dimension
    // of the input / output
    if (axes[0] > 0) {
      in = moveaxis(in, axes[0], 0, stream());
    }
    auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);
    out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());
    return {{out}, {0}};
  } else if (!in_vmap && w_vmap) {
    // flatten into the output channels of w
    // unflatten the channels of the output
    if (axes[1] > 0) {
      w = moveaxis(w, axes[1], 0, stream());
    }
    auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);
    out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());
    return {{out}, {static_cast<int>(out.ndim() - 2)}};
  } else if (in_vmap && w_vmap) {
    // use a group convolution when both inputs are vmapped
    auto b = in.shape(axes[0]);
    in = moveaxis(in, axes[0], -2, stream());
    in = flatten(in, -2, -1, stream());
    if (axes[1] > 0) {
      w = moveaxis(w, axes[1], 0, stream());
    }
    auto c_out = w.shape(1);
    w = flatten(w, 0, 1, stream());
    auto out = do_conv(in, w, groups_ * b);
    out = unflatten(out, -1, {b, c_out}, stream());
    return {{out}, {static_cast<int>(out.ndim() - 2)}};
  } else {
    return {{do_conv(in, w, groups_)}, {-1}};
  }
}

bool Convolution::is_equivalent(const Primitive& other) const {
  const Convolution& c_other = static_cast<const Convolution&>(other);
  return padding_ == c_other.padding_ &&
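The last branch above relies on the fact that b independent convolutions can be fused into a single grouped convolution by stacking the b vmapped examples along the channel axes. Below is a minimal standalone sketch of that equivalence (not part of the commit), written against the public conv2d / concatenate / split ops rather than the primitive internals; the shapes are made up for illustration.

#include "mlx/mlx.h"

#include <cassert>
#include <vector>

namespace mx = mlx::core;

int main() {
  int b = 3; // size of the vmapped axis
  std::vector<mx::array> xs, ws, refs;
  for (int i = 0; i < b; ++i) {
    xs.push_back(mx::random::normal({2, 8, 8, 4})); // (N, H, W, C_in)
    ws.push_back(mx::random::normal({5, 3, 3, 4})); // (C_out, kH, kW, C_in)
    refs.push_back(mx::conv2d(xs[i], ws[i]));       // b separate convolutions
  }

  // Stack inputs along the channel axis and weights along the output-channel
  // axis, then run one convolution with groups multiplied by b, mirroring the
  // groups_ * b call in the vmap rule above.
  auto x = mx::concatenate(xs, -1); // (2, 8, 8, b * 4)
  auto w = mx::concatenate(ws, 0);  // (b * 5, 3, 3, 4)
  auto y = mx::conv2d(x, w, {1, 1}, {0, 0}, {1, 1}, /* groups */ b);

  // Group g of the output channels matches the g-th standalone convolution.
  auto parts = mx::split(y, b, -1);
  for (int i = 0; i < b; ++i) {
    assert(mx::allclose(parts[i], refs[i]).item<bool>());
  }
  return 0;
}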
@@ -711,6 +711,7 @@ class Convolution : public UnaryPrimitive {
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_VMAP()
  DEFINE_PRINT(Convolution)
  bool is_equivalent(const Primitive& other) const override;
  auto state() const {
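With DEFINE_VMAP() wired into the Convolution primitive, convolutions can be traced through vmap from the public API. A hedged usage sketch, assuming the vector-valued vmap overload from mlx/transforms.h; the shapes are illustrative only.

#include "mlx/mlx.h"

#include <vector>

namespace mx = mlx::core;

int main() {
  // A leading vmapped axis of size 4 on both input and weight; this hits the
  // "both vmapped" branch above, which lowers to a grouped convolution.
  auto x = mx::random::normal({4, 2, 8, 8, 3});  // (vmap, N, H, W, C_in)
  auto w = mx::random::normal({4, 16, 3, 3, 3}); // (vmap, C_out, kH, kW, C_in)

  auto conv = [](const std::vector<mx::array>& inputs) {
    return std::vector<mx::array>{mx::conv2d(inputs[0], inputs[1])};
  };
  auto batched = mx::vmap(conv, /* in_axes */ {0, 0}, /* out_axes */ {0});

  auto y = batched({x, w})[0]; // (4, 2, 6, 6, 16)
  mx::eval(y);
  return 0;
}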