implement Convolution::output_shapes (#2601)

- pull conv_out_shape out for re-use
- add Convolution::output_shapes
- add e2e python tests confirming shapeless=True support and correctness

Updates #2599
This commit is contained in:
Josh Bleecher Snyder
2025-09-22 10:09:45 -07:00
committed by GitHub
parent ec2ab42888
commit aa9d44b3d4
4 changed files with 230 additions and 105 deletions

View File

@@ -3580,110 +3580,6 @@ array logcumsumexp(
namespace {
// Conv helpers
// Output length of one convolution axis: `padding` is the total (lo + hi)
// padding, `in_dim`/`wt_dim` the (already dilated) input and kernel extents.
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  const int span = in_dim + padding - wt_dim;
  return span / stride + 1;
}
// Conv helpers
// Effective extent of a dimension after dilation: the original elements plus
// (dil - 1) implicit gaps between each adjacent pair.
inline int dilate_size(int dim, int dil) {
  return dil * (dim - 1) + 1;
}
// Compute the output shape of an N-dimensional convolution.
//
// Layout is channels-last: in_shape is (N, spatial..., C_in) and wt_shape is
// (O, spatial..., C_in / groups). Every per-axis parameter vector must carry
// exactly one entry per spatial dimension.
//
// Throws std::invalid_argument when ranks or parameter sizes disagree, any
// stride/dilation is non-positive, any padding is negative, or the padded
// (dilated) input is smaller than the dilated kernel along some axis.
Shape conv_out_shape(
    const Shape& in_shape,
    const Shape& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  // Guard the ranks before indexing so malformed shapes raise a clean error
  // instead of reading out of bounds (in_shape[0] / wt_shape[i] below).
  if (in_shape.size() < 2 || wt_shape.size() != in_shape.size()) {
    std::ostringstream msg;
    msg << "[conv] Invalid input shape " << in_shape << " or weight shape "
        << wt_shape << " for convolution.";
    throw std::invalid_argument(msg.str());
  }
  // Unsigned count so the comparisons against vector sizes don't mix
  // signedness (the original `int` version triggered -Wsign-compare).
  const size_t spatial_dims = in_shape.size() - 2;
  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
  Shape out_shape(in_shape.size());
  out_shape[0] = in_shape[0]; // batch dimension passes through
  for (size_t i = 1; i <= spatial_dims; i++) {
    const size_t ax = i - 1; // index into the per-axis parameter vectors
    if (kernel_dilation[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }
    if (input_dilation[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }
    if (pads_lo[ax] < 0 || pads_hi[ax] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
          << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }
    if (strides[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Stride sizes must be positive." << " Got strides "
          << strides << ".";
      throw std::invalid_argument(msg.str());
    }
    // Dilate both kernel and input, then apply the standard sliding-window
    // size formula with the total (lo + hi) padding.
    int kd = dilate_size(wt_shape[i], kernel_dilation[ax]);
    int id = dilate_size(in_shape[i], input_dilation[ax]);
    out_shape[i] =
        conv_out_axis_size(id, kd, strides[ax], pads_lo[ax] + pads_hi[ax]);
    if (out_shape[i] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding"
          << " cannot be smaller than weight spatial dimensions."
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }
  }
  out_shape[in_shape.size() - 1] = wt_shape[0]; // output channels
  return out_shape;
}
inline void
run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
if (!issubdtype(in.dtype(), floating)) {
@@ -3997,7 +3893,7 @@ array conv_general(
}
// Get output shapes
auto out_shape = conv_out_shape(
auto out_shape = Convolution::conv_out_shape(
in.shape(),
wt.shape(),
stride,

View File

@@ -1243,6 +1243,114 @@ array conv_weight_backward_patches(
return grad;
}
namespace {
// Conv helpers.
// NOTE(review): these appear to duplicate the helpers in the ops source —
// consider sharing a single definition.
// Output length of one convolution axis; `padding` is the total (lo + hi)
// padding applied to that axis.
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  const int span = in_dim + padding - wt_dim;
  return span / stride + 1;
}
// Effective extent of a dimension after dilation: original elements plus
// (dil - 1) implicit gaps between each adjacent pair.
inline int dilate_size(int dim, int dil) {
  return dil * (dim - 1) + 1;
}
} // namespace
// Compute the output shape of an N-dimensional convolution.
//
// Layout is channels-last: in_shape is (N, spatial..., C_in) and wt_shape is
// (O, spatial..., C_in / groups). Every per-axis parameter vector must carry
// exactly one entry per spatial dimension.
//
// Throws std::invalid_argument when ranks or parameter sizes disagree, any
// stride/dilation is non-positive, any padding is negative, or the padded
// (dilated) input is smaller than the dilated kernel along some axis.
Shape Convolution::conv_out_shape(
    const Shape& in_shape,
    const Shape& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  // Guard the ranks before indexing so malformed shapes raise a clean error
  // instead of reading out of bounds (in_shape[0] / wt_shape[i] below).
  if (in_shape.size() < 2 || wt_shape.size() != in_shape.size()) {
    std::ostringstream msg;
    msg << "[conv] Invalid input shape " << in_shape << " or weight shape "
        << wt_shape << " for convolution.";
    throw std::invalid_argument(msg.str());
  }
  // Unsigned count so the comparisons against vector sizes don't mix
  // signedness (the original `int` version triggered -Wsign-compare).
  const size_t spatial_dims = in_shape.size() - 2;
  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
  Shape out_shape(in_shape.size());
  out_shape[0] = in_shape[0]; // batch dimension passes through
  for (size_t i = 1; i <= spatial_dims; i++) {
    const size_t ax = i - 1; // index into the per-axis parameter vectors
    if (kernel_dilation[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }
    if (input_dilation[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }
    if (pads_lo[ax] < 0 || pads_hi[ax] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
          << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }
    if (strides[ax] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Stride sizes must be positive." << " Got strides "
          << strides << ".";
      throw std::invalid_argument(msg.str());
    }
    // Dilate both kernel and input, then apply the standard sliding-window
    // size formula with the total (lo + hi) padding.
    int kd = dilate_size(wt_shape[i], kernel_dilation[ax]);
    int id = dilate_size(in_shape[i], input_dilation[ax]);
    out_shape[i] =
        conv_out_axis_size(id, kd, strides[ax], pads_lo[ax] + pads_hi[ax]);
    if (out_shape[i] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding"
          << " cannot be smaller than weight spatial dimensions."
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }
  }
  out_shape[in_shape.size() - 1] = wt_shape[0]; // output channels
  return out_shape;
}
std::vector<array> Convolution::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@@ -1454,6 +1562,18 @@ bool Convolution::is_equivalent(const Primitive& other) const {
groups_ == c_other.groups_ && flip_ == c_other.flip_;
}
// Shape inference for this primitive (used e.g. by shapeless export): derive
// the output shape from the (input, weight) shapes and the stored
// convolution parameters.
std::vector<Shape> Convolution::output_shapes(
    const std::vector<array>& inputs) {
  const auto& in = inputs[0];
  const auto& wt = inputs[1];
  auto out = conv_out_shape(
      in.shape(),
      wt.shape(),
      kernel_strides_,
      padding_lo_,
      padding_hi_,
      kernel_dilation_,
      input_dilation_);
  return {std::move(out)};
}
std::vector<array> Copy::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -750,6 +750,7 @@ class Convolution : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_NAME(Convolution)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(
kernel_strides_,
@@ -761,6 +762,15 @@ class Convolution : public UnaryPrimitive {
flip_);
}
static Shape conv_out_shape(
const Shape& in_shape,
const Shape& wt_shape,
const std::vector<int>& strides,
const std::vector<int>& pads_lo,
const std::vector<int>& pads_hi,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation);
private:
std::vector<int> padding_lo_;
std::vector<int> padding_hi_;

View File

@@ -346,6 +346,105 @@ class TestExportImport(mlx_tests.MLXTestCase):
expected = forward(input_data)
self.assertTrue(mx.allclose(expected, out))
def test_export_conv_shapeless(self):
    """Exported conv graphs with shapeless=True must re-infer output shapes.

    Exercises Conv1d/Conv2d/Conv3d and a grouped Conv2d end to end: export a
    module once, re-import it, and check the imported function matches the
    live module for several input shapes, including shapes different from
    the one used at export time.
    """

    def check(fname, make_conv, x_shape, test_shapes):
        # Wrap the conv built by `make_conv` in a module, export it
        # shapeless, re-import, and compare against the live module for
        # every shape in `test_shapes`.
        path = os.path.join(self.test_dir, fname)

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = make_conv()

            def __call__(self, x):
                return self.c(x)

        m = M()
        mx.eval(m.parameters())

        def f(x):
            return m(x)

        x = mx.random.normal(shape=x_shape)
        mx.export_function(path, f, x, shapeless=True)
        f_imp = mx.import_function(path)
        for shape in test_shapes:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(f_imp(xt)[0], f(xt)))

    # Conv1d (NLC)
    check(
        "conv1d.mlxfn",
        lambda: nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False),
        (4, 64, 3),
        [(4, 64, 3), (1, 33, 3), (2, 128, 3)],
    )
    # Conv2d (NHWC)
    check(
        "conv2d.mlxfn",
        lambda: nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False),
        (2, 32, 32, 3),
        [(2, 32, 32, 3), (1, 31, 31, 3), (4, 64, 48, 3)],
    )
    # Conv3d (NDHWC)
    check(
        "conv3d.mlxfn",
        lambda: nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1, bias=False),
        (1, 8, 8, 8, 2),
        [(1, 8, 8, 8, 2), (2, 7, 8, 9, 2), (1, 16, 16, 4, 2)],
    )
    # Grouped Conv2d (NHWC)
    check(
        "conv2d_grouped.mlxfn",
        lambda: nn.Conv2d(
            4, 6, kernel_size=3, stride=2, padding=1, groups=2, bias=False
        ),
        (2, 32, 32, 4),
        [(2, 32, 32, 4), (1, 32, 32, 4), (3, 15, 20, 4)],
    )
def test_export_control_flow(self):
def fun(x, y):