implement Convolution::output_shape (#2601)

- pull conv_out_shape out for re-use
- add Conv::output_shape
- add e2e python tests confirming shapeless=True support and correctness

Updates #2599
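
For context, this is the end-to-end behavior the change enables, in the spirit of the tests added below: a function containing a convolution can be exported with shapeless=True and re-imported, and the imported function then accepts inputs whose batch and spatial sizes differ from the example used at export. A minimal sketch (the layer configuration mirrors the Conv1d test; the file path is illustrative):

import mlx.core as mx
import mlx.nn as nn

conv = nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)
mx.eval(conv.parameters())

def f(x):
    return conv(x)

x = mx.random.normal(shape=(4, 64, 3))  # NLC layout
mx.export_function("conv1d.mlxfn", f, x, shapeless=True)
f_imp = mx.import_function("conv1d.mlxfn")

# Other lengths/batches now work, because Convolution can recompute
# its output shape from the runtime input shape.
y = f_imp(mx.random.normal(shape=(1, 33, 3)))[0]  # shape (1, 17, 8)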
This commit is contained in:
committed by GitHub

parent ec2ab42888
commit aa9d44b3d4

mlx/ops.cpp (106 changed lines)
@@ -3580,110 +3580,6 @@ array logcumsumexp(
namespace {

// Conv helpers
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  return ((in_dim + padding - wt_dim) / stride) + 1;
}

// Conv helpers
inline int dilate_size(int dim, int dil) {
  return 1 + dil * (dim - 1);
}

Shape conv_out_shape(
    const Shape& in_shape,
    const Shape& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  int N = in_shape[0];
  int O = wt_shape[0];
  Shape out_shape(in_shape.size());
  int i = 0;
  out_shape[i++] = N;

  int spatial_dims = in_shape.size() - 2;

  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  for (; i < in_shape.size() - 1; i++) {
    if (kernel_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (input_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
          << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }

    if (strides[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Stride sizes must be positive." << " Got strides "
          << strides << ".";
      throw std::invalid_argument(msg.str());
    }

    int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);
    int id = dilate_size(in_shape[i], input_dilation[i - 1]);

    out_shape[i] = conv_out_axis_size(
        id, kd, strides[i - 1], pads_lo[i - 1] + pads_hi[i - 1]);

    if (out_shape[i] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding"
          << " cannot be smaller than weight spatial dimensions."
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }
  }
  out_shape[i] = O;

  return out_shape;
}

inline void
run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
  if (!issubdtype(in.dtype(), floating)) {
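
As a sanity check on the two helpers in this hunk, here is a hedged Python transcription of the same arithmetic (// mirrors the C++ integer division for the non-negative values involved), worked through on the Conv1d test configuration used later (L=64, kernel 3, stride 2, padding 1 on each side):

def dilate_size(dim, dil):
    # effective extent of a size-`dim` window under dilation `dil`
    return 1 + dil * (dim - 1)

def conv_out_axis_size(in_dim, wt_dim, stride, padding):
    # `padding` is pads_lo + pads_hi combined
    return (in_dim + padding - wt_dim) // stride + 1

kd = dilate_size(3, 1)                    # 3: dilation 1 leaves the kernel as-is
print(conv_out_axis_size(64, kd, 2, 2))   # ((64 + 2 - 3) // 2) + 1 = 32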

@@ -3997,7 +3893,7 @@ array conv_general(
  }

  // Get output shapes
-  auto out_shape = conv_out_shape(
+  auto out_shape = Convolution::conv_out_shape(
      in.shape(),
      wt.shape(),
      stride,
mlx/primitives.cpp

@@ -1243,6 +1243,114 @@ array conv_weight_backward_patches(
  return grad;
}

namespace {

// Conv helpers
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  return ((in_dim + padding - wt_dim) / stride) + 1;
}

// Conv helpers
inline int dilate_size(int dim, int dil) {
  return 1 + dil * (dim - 1);
}

} // namespace

Shape Convolution::conv_out_shape(
    const Shape& in_shape,
    const Shape& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  int N = in_shape[0];
  int O = wt_shape[0];
  Shape out_shape(in_shape.size());
  int i = 0;
  out_shape[i++] = N;

  int spatial_dims = in_shape.size() - 2;

  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  for (; i < in_shape.size() - 1; i++) {
    if (kernel_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (input_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
          << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }

    if (strides[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Stride sizes must be positive." << " Got strides "
          << strides << ".";
      throw std::invalid_argument(msg.str());
    }

    int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);
    int id = dilate_size(in_shape[i], input_dilation[i - 1]);

    out_shape[i] = conv_out_axis_size(
        id, kd, strides[i - 1], pads_lo[i - 1] + pads_hi[i - 1]);

    if (out_shape[i] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding"
          << " cannot be smaller than weight spatial dimensions."
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }
  }
  out_shape[i] = O;

  return out_shape;
}

std::vector<array> Convolution::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,

@@ -1454,6 +1562,18 @@ bool Convolution::is_equivalent(const Primitive& other) const {
      groups_ == c_other.groups_ && flip_ == c_other.flip_;
}

std::vector<Shape> Convolution::output_shapes(
    const std::vector<array>& inputs) {
  return {conv_out_shape(
      inputs[0].shape(), // in_shape
      inputs[1].shape(), // wt_shape
      kernel_strides_,
      padding_lo_,
      padding_hi_,
      kernel_dilation_,
      input_dilation_)};
}

std::vector<array> Copy::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
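
The new output_shapes override must agree with what the eager op produces. A quick hedged check for the grouped Conv2d configuration used in the tests below (the weight layout is assumed channels-last, (O, kH, kW, C_in / groups), per MLX's convention):

import mlx.core as mx

x = mx.random.normal(shape=(2, 32, 32, 4))  # NHWC
w = mx.random.normal(shape=(6, 3, 3, 2))    # O=6, groups=2 so C_in/groups = 2
y = mx.conv2d(x, w, stride=2, padding=1, groups=2)
print(y.shape)  # (2, 16, 16, 6): N kept, ((32 + 2 - 3) // 2) + 1 = 16 per axis, then O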
mlx/primitives.h

@@ -750,6 +750,7 @@ class Convolution : public UnaryPrimitive {
  DEFINE_VMAP()
  DEFINE_NAME(Convolution)
  bool is_equivalent(const Primitive& other) const override;
  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
  auto state() const {
    return std::make_tuple(
        kernel_strides_,

@@ -761,6 +762,15 @@ class Convolution : public UnaryPrimitive {
        flip_);
  }

  static Shape conv_out_shape(
      const Shape& in_shape,
      const Shape& wt_shape,
      const std::vector<int>& strides,
      const std::vector<int>& pads_lo,
      const std::vector<int>& pads_hi,
      const std::vector<int>& kernel_dilation,
      const std::vector<int>& input_dilation);

 private:
  std::vector<int> padding_lo_;
  std::vector<int> padding_hi_;
python/tests/test_export_import.py

@@ -346,6 +346,105 @@ class TestExportImport(mlx_tests.MLXTestCase):
        expected = forward(input_data)
        self.assertTrue(mx.allclose(expected, out))

    def test_export_conv_shapeless(self):
        # Conv1d (NLC)
        path = os.path.join(self.test_dir, "conv1d.mlxfn")

        class M1(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)

            def __call__(self, x):
                return self.c(x)

        m1 = M1()
        mx.eval(m1.parameters())

        def f1(x):
            return m1(x)

        x = mx.random.normal(shape=(4, 64, 3))
        mx.export_function(path, f1, x, shapeless=True)
        f1_imp = mx.import_function(path)
        for shape in [(4, 64, 3), (1, 33, 3), (2, 128, 3)]:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(f1_imp(xt)[0], f1(xt)))

        # Conv2d (NHWC)
        path = os.path.join(self.test_dir, "conv2d.mlxfn")

        class M2(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False)

            def __call__(self, x):
                return self.c(x)

        m2 = M2()
        mx.eval(m2.parameters())

        def f2(x):
            return m2(x)

        x = mx.random.normal(shape=(2, 32, 32, 3))
        mx.export_function(path, f2, x, shapeless=True)
        f2_imp = mx.import_function(path)
        for shape in [(2, 32, 32, 3), (1, 31, 31, 3), (4, 64, 48, 3)]:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(f2_imp(xt)[0], f2(xt)))

        # Conv3d (NDHWC)
        path = os.path.join(self.test_dir, "conv3d.mlxfn")

        class M3(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1, bias=False)

            def __call__(self, x):
                return self.c(x)

        m3 = M3()
        mx.eval(m3.parameters())

        def f3(x):
            return m3(x)

        x = mx.random.normal(shape=(1, 8, 8, 8, 2))
        mx.export_function(path, f3, x, shapeless=True)
        f3_imp = mx.import_function(path)
        for shape in [(1, 8, 8, 8, 2), (2, 7, 8, 9, 2), (1, 16, 16, 4, 2)]:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(f3_imp(xt)[0], f3(xt)))

        # Grouped Conv2d (NHWC)
        path = os.path.join(self.test_dir, "conv2d_grouped.mlxfn")

        class MG(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = nn.Conv2d(
                    4, 6, kernel_size=3, stride=2, padding=1, groups=2, bias=False
                )

            def __call__(self, x):
                return self.c(x)

        mg = MG()
        mx.eval(mg.parameters())

        def fg(x):
            return mg(x)

        x = mx.random.normal(shape=(2, 32, 32, 4))
        mx.export_function(path, fg, x, shapeless=True)
        fg_imp = mx.import_function(path)
        for shape in [(2, 32, 32, 4), (1, 32, 32, 4), (3, 15, 20, 4)]:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(fg_imp(xt)[0], fg(xt)))

    def test_export_control_flow(self):

        def fun(x, y):
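
To run just the new tests locally, something like the following should work (assuming the test file keeps its standard unittest entry point; adjust the path to your checkout):

python python/tests/test_export_import.py TestExportImport.test_export_conv_shapeless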