mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
implement Convolution::output_shapes (#2601)
- pull conv_out_shape out for re-use
- add Convolution::output_shapes
- add e2e python tests confirming shapeless=True support and correctness

Updates #2599
This commit is contained in:
committed by
GitHub
parent
ec2ab42888
commit
aa9d44b3d4
106
mlx/ops.cpp
106
mlx/ops.cpp
@@ -3580,110 +3580,6 @@ array logcumsumexp(
|
||||
|
||||
namespace {
|
||||
|
||||
// Conv helpers
|
||||
// Number of output positions along a single spatial axis.
//
//   in_dim  - input extent along the axis
//   wt_dim  - kernel extent along the axis
//   stride  - step between consecutive kernel placements (expected > 0)
//   padding - total padding added to the axis (low + high)
//
// Returns 0 when the kernel does not fit into the padded input. The explicit
// guard is required because integer division truncates toward zero: e.g.
// (3 + 0 - 4) / 2 + 1 would evaluate to 1 rather than a non-positive value,
// hiding the "kernel larger than input" condition from callers that test the
// result with `<= 0`.
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  const int span = in_dim + padding - wt_dim;
  if (span < 0) {
    return 0;
  }
  return (span / stride) + 1;
}
|
||||
|
||||
// Conv helpers
|
||||
// Extent spanned by `dim` elements once neighbouring elements are placed `dil`
// positions apart: one slot for the first element plus `dil` slots for each of
// the remaining `dim - 1` elements.
inline int dilate_size(int dim, int dil) {
  const int gaps = dim - 1;
  return gaps * dil + 1;
}
|
||||
|
||||
// Validate convolution parameters and compute the resulting output shape.
//
// The result has as many axes as `in_shape`: axis 0 is copied from
// `in_shape[0]`, the last axis is `wt_shape[0]`, and every axis in between is
// a spatial axis whose extent is derived from the dilated input and kernel
// extents, the total padding and the stride.
//
//   in_shape        - input shape; needs at least two axes
//   wt_shape        - weight shape; axis 0 and one axis per spatial axis are
//                     read
//   strides         - one entry per spatial axis, each must be positive
//   pads_lo/pads_hi - one entry per spatial axis, each must be non-negative
//   kernel_dilation - one entry per spatial axis, each must be positive
//   input_dilation  - one entry per spatial axis, each must be positive
//
// Throws std::invalid_argument when the ranks are too small to index, when a
// parameter vector has the wrong length or an out-of-range entry, or when the
// dilated kernel is larger than the padded, dilated input along some axis.
Shape conv_out_shape(
    const Shape& in_shape,
    const Shape& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  const size_t ndim = in_shape.size();
  const size_t wt_ndim = wt_shape.size();

  // Reject ranks for which the indexing below would read out of bounds.
  if (ndim < 2 || wt_ndim + 1 < ndim) {
    std::ostringstream msg;
    msg << "[conv] Invalid input shape " << in_shape << " and weight shape "
        << wt_shape << " for convolution.";
    throw std::invalid_argument(msg.str());
  }

  const size_t spatial_dims = ndim - 2;

  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  Shape out_shape(in_shape.size());
  out_shape[0] = in_shape[0];

  for (size_t ax = 0; ax < spatial_dims; ++ax) {
    // Entry `ax` of the parameter vectors describes axis `ax + 1` of the
    // shapes, because axis 0 of the shapes is not a spatial axis.
    const size_t i = ax + 1;

    if (kernel_dilation[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (input_dilation[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (pads_lo[ax] < 0 || pads_hi[ax] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
          << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }

    if (strides[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Stride sizes must be positive." << " Got strides "
          << strides << ".";
      throw std::invalid_argument(msg.str());
    }

    const int kd = dilate_size(wt_shape[i], kernel_dilation[ax]);
    const int id = dilate_size(in_shape[i], input_dilation[ax]);
    const int padding = pads_lo[ax] + pads_hi[ax];

    out_shape[i] = conv_out_axis_size(id, kd, strides[ax], padding);

    // Compare the extents directly as well as the computed size: truncating
    // integer division can turn a negative numerator into a positive output
    // size when the stride is larger than one.
    if (id + padding < kd || out_shape[i] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding"
          << " cannot be smaller than weight spatial dimensions."
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }
  }

  out_shape[ndim - 1] = wt_shape[0];

  return out_shape;
}
|
||||
|
||||
inline void
|
||||
run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
|
||||
if (!issubdtype(in.dtype(), floating)) {
|
||||
@@ -3997,7 +3893,7 @@ array conv_general(
|
||||
}
|
||||
|
||||
// Get output shapes
|
||||
auto out_shape = conv_out_shape(
|
||||
auto out_shape = Convolution::conv_out_shape(
|
||||
in.shape(),
|
||||
wt.shape(),
|
||||
stride,
|
||||
|
||||
@@ -1243,6 +1243,114 @@ array conv_weight_backward_patches(
|
||||
return grad;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Conv helpers
|
||||
// Number of output positions along a single spatial axis.
//
//   in_dim  - input extent along the axis
//   wt_dim  - kernel extent along the axis
//   stride  - step between consecutive kernel placements (expected > 0)
//   padding - total padding added to the axis (low + high)
//
// Returns 0 when the kernel does not fit into the padded input. The explicit
// guard is required because integer division truncates toward zero: e.g.
// (3 + 0 - 4) / 2 + 1 would evaluate to 1 rather than a non-positive value,
// hiding the "kernel larger than input" condition from callers that test the
// result with `<= 0`.
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  const int span = in_dim + padding - wt_dim;
  if (span < 0) {
    return 0;
  }
  return (span / stride) + 1;
}
|
||||
|
||||
// Conv helpers
|
||||
// Extent spanned by `dim` elements once neighbouring elements are placed `dil`
// positions apart: one slot for the first element plus `dil` slots for each of
// the remaining `dim - 1` elements.
inline int dilate_size(int dim, int dil) {
  const int gaps = dim - 1;
  return gaps * dil + 1;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Validate convolution parameters and compute the resulting output shape.
//
// The result has as many axes as `in_shape`: axis 0 is copied from
// `in_shape[0]`, the last axis is `wt_shape[0]`, and every axis in between is
// a spatial axis whose extent is derived from the dilated input and kernel
// extents, the total padding and the stride.
//
//   in_shape        - input shape; needs at least two axes
//   wt_shape        - weight shape; axis 0 and one axis per spatial axis are
//                     read
//   strides         - one entry per spatial axis, each must be positive
//   pads_lo/pads_hi - one entry per spatial axis, each must be non-negative
//   kernel_dilation - one entry per spatial axis, each must be positive
//   input_dilation  - one entry per spatial axis, each must be positive
//
// Throws std::invalid_argument when the ranks are too small to index, when a
// parameter vector has the wrong length or an out-of-range entry, or when the
// dilated kernel is larger than the padded, dilated input along some axis.
Shape Convolution::conv_out_shape(
    const Shape& in_shape,
    const Shape& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  const size_t ndim = in_shape.size();
  const size_t wt_ndim = wt_shape.size();

  // Reject ranks for which the indexing below would read out of bounds.
  if (ndim < 2 || wt_ndim + 1 < ndim) {
    std::ostringstream msg;
    msg << "[conv] Invalid input shape " << in_shape << " and weight shape "
        << wt_shape << " for convolution.";
    throw std::invalid_argument(msg.str());
  }

  const size_t spatial_dims = ndim - 2;

  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  Shape out_shape(in_shape.size());
  out_shape[0] = in_shape[0];

  for (size_t ax = 0; ax < spatial_dims; ++ax) {
    // Entry `ax` of the parameter vectors describes axis `ax + 1` of the
    // shapes, because axis 0 of the shapes is not a spatial axis.
    const size_t i = ax + 1;

    if (kernel_dilation[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (input_dilation[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (pads_lo[ax] < 0 || pads_hi[ax] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
          << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }

    if (strides[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Stride sizes must be positive." << " Got strides "
          << strides << ".";
      throw std::invalid_argument(msg.str());
    }

    const int kd = dilate_size(wt_shape[i], kernel_dilation[ax]);
    const int id = dilate_size(in_shape[i], input_dilation[ax]);
    const int padding = pads_lo[ax] + pads_hi[ax];

    out_shape[i] = conv_out_axis_size(id, kd, strides[ax], padding);

    // Compare the extents directly as well as the computed size: truncating
    // integer division can turn a negative numerator into a positive output
    // size when the stride is larger than one.
    if (id + padding < kd || out_shape[i] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding"
          << " cannot be smaller than weight spatial dimensions."
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }
  }

  out_shape[ndim - 1] = wt_shape[0];

  return out_shape;
}
|
||||
|
||||
std::vector<array> Convolution::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
@@ -1454,6 +1562,18 @@ bool Convolution::is_equivalent(const Primitive& other) const {
|
||||
groups_ == c_other.groups_ && flip_ == c_other.flip_;
|
||||
}
|
||||
|
||||
std::vector<Shape> Convolution::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
return {conv_out_shape(
|
||||
inputs[0].shape(), // in_shape
|
||||
inputs[1].shape(), // wt_shape
|
||||
kernel_strides_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_dilation_,
|
||||
input_dilation_)};
|
||||
}
|
||||
|
||||
std::vector<array> Copy::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
||||
@@ -750,6 +750,7 @@ class Convolution : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_NAME(Convolution)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
kernel_strides_,
|
||||
@@ -761,6 +762,15 @@ class Convolution : public UnaryPrimitive {
|
||||
flip_);
|
||||
}
|
||||
|
||||
static Shape conv_out_shape(
|
||||
const Shape& in_shape,
|
||||
const Shape& wt_shape,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& pads_lo,
|
||||
const std::vector<int>& pads_hi,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation);
|
||||
|
||||
private:
|
||||
std::vector<int> padding_lo_;
|
||||
std::vector<int> padding_hi_;
|
||||
|
||||
Reference in New Issue
Block a user