diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index a2271c4fd..c1d16ba1f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3580,110 +3580,6 @@ array logcumsumexp(
 
 namespace {
 
-// Conv helpers
-inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
-  return ((in_dim + padding - wt_dim) / stride) + 1;
-}
-
-// Conv helpers
-inline int dilate_size(int dim, int dil) {
-  return 1 + dil * (dim - 1);
-}
-
-Shape conv_out_shape(
-    const Shape& in_shape,
-    const Shape& wt_shape,
-    const std::vector<int>& strides,
-    const std::vector<int>& pads_lo,
-    const std::vector<int>& pads_hi,
-    const std::vector<int>& kernel_dilation,
-    const std::vector<int>& input_dilation) {
-  int N = in_shape[0];
-  int O = wt_shape[0];
-  Shape out_shape(in_shape.size());
-  int i = 0;
-  out_shape[i++] = N;
-
-  int spatial_dims = in_shape.size() - 2;
-
-  if (strides.size() != spatial_dims) {
-    std::ostringstream msg;
-    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
-        << "D convolution.";
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
-    std::ostringstream msg;
-    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
-        << spatial_dims << "D convolution.";
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (kernel_dilation.size() != spatial_dims) {
-    std::ostringstream msg;
-    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
-        << spatial_dims << "D convolution.";
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (input_dilation.size() != spatial_dims) {
-    std::ostringstream msg;
-    msg << "[conv] Invalid input dilation " << input_dilation << " for "
-        << spatial_dims << "D convolution.";
-    throw std::invalid_argument(msg.str());
-  }
-
-  for (; i < in_shape.size() - 1; i++) {
-    if (kernel_dilation[i - 1] <= 0) {
-      std::ostringstream msg;
-      msg << "[conv] Kernel dilation sizes must be positive."
-          << " Got kernel dilation " << kernel_dilation << ".";
-      throw std::invalid_argument(msg.str());
-    }
-
-    if (input_dilation[i - 1] <= 0) {
-      std::ostringstream msg;
-      msg << "[conv] Input dilation sizes must be positive."
-          << " Got input dilation " << input_dilation << ".";
-      throw std::invalid_argument(msg.str());
-    }
-
-    if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
-      std::ostringstream msg;
-      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
-          << pads_lo << " | " << pads_hi << ".";
-      throw std::invalid_argument(msg.str());
-    }
-
-    if (strides[i - 1] <= 0) {
-      std::ostringstream msg;
-      msg << "[conv] Stride sizes must be positive." << " Got strides "
-          << strides << ".";
-      throw std::invalid_argument(msg.str());
-    }
-
-    int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);
-    int id = dilate_size(in_shape[i], input_dilation[i - 1]);
-
-    out_shape[i] = conv_out_axis_size(
-        id, kd, strides[i - 1], pads_lo[i - 1] + pads_hi[i - 1]);
-
-    if (out_shape[i] <= 0) {
-      std::ostringstream msg;
-      msg << "[conv] Spatial dimensions of input after padding"
-          << " cannot be smaller than weight spatial dimensions."
-          << " Got error at axis " << i << " for input with shape " << in_shape
-          << ", padding low " << pads_lo << ", padding high " << pads_hi
-          << ", and weight of shape " << wt_shape << ".";
-      throw std::invalid_argument(msg.str());
-    }
-  }
-  out_shape[i] = O;
-
-  return out_shape;
-}
-
 inline void run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
   if (!issubdtype(in.dtype(), floating)) {
@@ -3997,7 +3893,7 @@ array conv_general(
   }
 
   // Get output shapes
-  auto out_shape = conv_out_shape(
+  auto out_shape = Convolution::conv_out_shape(
       in.shape(),
       wt.shape(),
       stride,
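Note: the helpers removed here (and re-added in mlx/primitives.cpp below) implement the standard strided-convolution output-size arithmetic. A minimal Python sketch of the same formula, for reference only (the names are illustrative, not part of this patch):

    def dilate_size(dim, dil):
        # Effective axis size after dilation; dil == 1 leaves it unchanged.
        return 1 + dil * (dim - 1)

    def conv_out_axis_size(in_dim, wt_dim, stride, pad_lo, pad_hi, k_dil=1, in_dil=1):
        kd = dilate_size(wt_dim, k_dil)     # dilated kernel extent
        idim = dilate_size(in_dim, in_dil)  # dilated input extent
        return (idim + pad_lo + pad_hi - kd) // stride + 1

    # Matches the Conv1d case in the tests below: L=64, k=3, stride=2, padding=1 -> 32
    assert conv_out_axis_size(64, 3, 2, 1, 1) == 32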
- << " Got error at axis " << i << " for input with shape " << in_shape - << ", padding low " << pads_lo << ", padding high " << pads_hi - << ", and weight of shape " << wt_shape << "."; - throw std::invalid_argument(msg.str()); - } - } - out_shape[i] = O; - - return out_shape; -} - inline void run_conv_checks(const array& in, const array& wt, int n_dim, int groups) { if (!issubdtype(in.dtype(), floating)) { @@ -3997,7 +3893,7 @@ array conv_general( } // Get output shapes - auto out_shape = conv_out_shape( + auto out_shape = Convolution::conv_out_shape( in.shape(), wt.shape(), stride, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 1406fd46f..655a55910 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1243,6 +1243,114 @@ array conv_weight_backward_patches( return grad; } +namespace { + +// Conv helpers +inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) { + return ((in_dim + padding - wt_dim) / stride) + 1; +} + +// Conv helpers +inline int dilate_size(int dim, int dil) { + return 1 + dil * (dim - 1); +} + +} // namespace + +Shape Convolution::conv_out_shape( + const Shape& in_shape, + const Shape& wt_shape, + const std::vector& strides, + const std::vector& pads_lo, + const std::vector& pads_hi, + const std::vector& kernel_dilation, + const std::vector& input_dilation) { + int N = in_shape[0]; + int O = wt_shape[0]; + Shape out_shape(in_shape.size()); + int i = 0; + out_shape[i++] = N; + + int spatial_dims = in_shape.size() - 2; + + if (strides.size() != spatial_dims) { + std::ostringstream msg; + msg << "[conv] Invalid strides " << strides << " for " << spatial_dims + << "D convolution."; + throw std::invalid_argument(msg.str()); + } + + if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) { + std::ostringstream msg; + msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for " + << spatial_dims << "D convolution."; + throw std::invalid_argument(msg.str()); + } + + if (kernel_dilation.size() != spatial_dims) { + std::ostringstream msg; + msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for " + << spatial_dims << "D convolution."; + throw std::invalid_argument(msg.str()); + } + + if (input_dilation.size() != spatial_dims) { + std::ostringstream msg; + msg << "[conv] Invalid input dilation " << input_dilation << " for " + << spatial_dims << "D convolution."; + throw std::invalid_argument(msg.str()); + } + + for (; i < in_shape.size() - 1; i++) { + if (kernel_dilation[i - 1] <= 0) { + std::ostringstream msg; + msg << "[conv] Kernel dilation sizes must be positive." + << " Got kernel dilation " << kernel_dilation << "."; + throw std::invalid_argument(msg.str()); + } + + if (input_dilation[i - 1] <= 0) { + std::ostringstream msg; + msg << "[conv] Input dilation sizes must be positive." + << " Got input dilation " << input_dilation << "."; + throw std::invalid_argument(msg.str()); + } + + if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) { + std::ostringstream msg; + msg << "[conv] Padding sizes must be non-negative." << " Got padding " + << pads_lo << " | " << pads_hi << "."; + throw std::invalid_argument(msg.str()); + } + + if (strides[i - 1] <= 0) { + std::ostringstream msg; + msg << "[conv] Stride sizes must be positive." 
<< " Got strides " + << strides << "."; + throw std::invalid_argument(msg.str()); + } + + int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]); + int id = dilate_size(in_shape[i], input_dilation[i - 1]); + + out_shape[i] = conv_out_axis_size( + id, kd, strides[i - 1], pads_lo[i - 1] + pads_hi[i - 1]); + + if (out_shape[i] <= 0) { + std::ostringstream msg; + msg << "[conv] Spatial dimensions of input after padding" + << " cannot be smaller than weight spatial dimensions." + << " Got error at axis " << i << " for input with shape " << in_shape + << ", padding low " << pads_lo << ", padding high " << pads_hi + << ", and weight of shape " << wt_shape << "."; + throw std::invalid_argument(msg.str()); + } + } + out_shape[i] = O; + + return out_shape; +} + std::vector Convolution::vjp( const std::vector& primals, const std::vector& cotangents, @@ -1454,6 +1562,18 @@ bool Convolution::is_equivalent(const Primitive& other) const { groups_ == c_other.groups_ && flip_ == c_other.flip_; } +std::vector Convolution::output_shapes( + const std::vector& inputs) { + return {conv_out_shape( + inputs[0].shape(), // in_shape + inputs[1].shape(), // wt_shape + kernel_strides_, + padding_lo_, + padding_hi_, + kernel_dilation_, + input_dilation_)}; +} + std::vector Copy::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 986675f3a..2a843a0e4 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -750,6 +750,7 @@ class Convolution : public UnaryPrimitive { DEFINE_VMAP() DEFINE_NAME(Convolution) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; auto state() const { return std::make_tuple( kernel_strides_, @@ -761,6 +762,15 @@ class Convolution : public UnaryPrimitive { flip_); } + static Shape conv_out_shape( + const Shape& in_shape, + const Shape& wt_shape, + const std::vector& strides, + const std::vector& pads_lo, + const std::vector& pads_hi, + const std::vector& kernel_dilation, + const std::vector& input_dilation); + private: std::vector padding_lo_; std::vector padding_hi_; diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 71fb5b27f..1d8af8509 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -346,6 +346,105 @@ class TestExportImport(mlx_tests.MLXTestCase): expected = forward(input_data) self.assertTrue(mx.allclose(expected, out)) + def test_export_conv_shapeless(self): + # Conv1d (NLC) + path = os.path.join(self.test_dir, "conv1d.mlxfn") + + class M1(nn.Module): + def __init__(self): + super().__init__() + self.c = nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False) + + def __call__(self, x): + return self.c(x) + + m1 = M1() + mx.eval(m1.parameters()) + + def f1(x): + return m1(x) + + x = mx.random.normal(shape=(4, 64, 3)) + mx.export_function(path, f1, x, shapeless=True) + f1_imp = mx.import_function(path) + for shape in [(4, 64, 3), (1, 33, 3), (2, 128, 3)]: + xt = mx.random.normal(shape=shape) + self.assertTrue(mx.allclose(f1_imp(xt)[0], f1(xt))) + + # Conv2d (NHWC) + path = os.path.join(self.test_dir, "conv2d.mlxfn") + + class M2(nn.Module): + def __init__(self): + super().__init__() + self.c = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False) + + def __call__(self, x): + return self.c(x) + + m2 = M2() + mx.eval(m2.parameters()) + + def f2(x): + return m2(x) + + x = mx.random.normal(shape=(2, 32, 32, 3)) + mx.export_function(path, f2, x, 
diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py
index 71fb5b27f..1d8af8509 100644
--- a/python/tests/test_export_import.py
+++ b/python/tests/test_export_import.py
@@ -346,6 +346,105 @@ class TestExportImport(mlx_tests.MLXTestCase):
         expected = forward(input_data)
         self.assertTrue(mx.allclose(expected, out))
 
+    def test_export_conv_shapeless(self):
+        # Conv1d (NLC)
+        path = os.path.join(self.test_dir, "conv1d.mlxfn")
+
+        class M1(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.c = nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)
+
+            def __call__(self, x):
+                return self.c(x)
+
+        m1 = M1()
+        mx.eval(m1.parameters())
+
+        def f1(x):
+            return m1(x)
+
+        x = mx.random.normal(shape=(4, 64, 3))
+        mx.export_function(path, f1, x, shapeless=True)
+        f1_imp = mx.import_function(path)
+        for shape in [(4, 64, 3), (1, 33, 3), (2, 128, 3)]:
+            xt = mx.random.normal(shape=shape)
+            self.assertTrue(mx.allclose(f1_imp(xt)[0], f1(xt)))
+
+        # Conv2d (NHWC)
+        path = os.path.join(self.test_dir, "conv2d.mlxfn")
+
+        class M2(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.c = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False)
+
+            def __call__(self, x):
+                return self.c(x)
+
+        m2 = M2()
+        mx.eval(m2.parameters())
+
+        def f2(x):
+            return m2(x)
+
+        x = mx.random.normal(shape=(2, 32, 32, 3))
+        mx.export_function(path, f2, x, shapeless=True)
+        f2_imp = mx.import_function(path)
+        for shape in [(2, 32, 32, 3), (1, 31, 31, 3), (4, 64, 48, 3)]:
+            xt = mx.random.normal(shape=shape)
+            self.assertTrue(mx.allclose(f2_imp(xt)[0], f2(xt)))
+
+        # Conv3d (NDHWC)
+        path = os.path.join(self.test_dir, "conv3d.mlxfn")
+
+        class M3(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.c = nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1, bias=False)
+
+            def __call__(self, x):
+                return self.c(x)
+
+        m3 = M3()
+        mx.eval(m3.parameters())
+
+        def f3(x):
+            return m3(x)
+
+        x = mx.random.normal(shape=(1, 8, 8, 8, 2))
+        mx.export_function(path, f3, x, shapeless=True)
+        f3_imp = mx.import_function(path)
+        for shape in [(1, 8, 8, 8, 2), (2, 7, 8, 9, 2), (1, 16, 16, 4, 2)]:
+            xt = mx.random.normal(shape=shape)
+            self.assertTrue(mx.allclose(f3_imp(xt)[0], f3(xt)))
+
+        # Grouped Conv2d (NHWC)
+        path = os.path.join(self.test_dir, "conv2d_grouped.mlxfn")
+
+        class MG(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.c = nn.Conv2d(
+                    4, 6, kernel_size=3, stride=2, padding=1, groups=2, bias=False
+                )
+
+            def __call__(self, x):
+                return self.c(x)
+
+        mg = MG()
+        mx.eval(mg.parameters())
+
+        def fg(x):
+            return mg(x)
+
+        x = mx.random.normal(shape=(2, 32, 32, 4))
+        mx.export_function(path, fg, x, shapeless=True)
+        fg_imp = mx.import_function(path)
+        for shape in [(2, 32, 32, 4), (1, 32, 32, 4), (3, 15, 20, 4)]:
+            xt = mx.random.normal(shape=shape)
+            self.assertTrue(mx.allclose(fg_imp(xt)[0], fg(xt)))
+
     def test_export_control_flow(self):
         def fun(x, y):
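For a quick sanity check outside the test suite, the shapeless round-trip can also be sketched without nn.Module. This is a minimal sketch assuming MLX's channels-last conv1d convention (NLC input, (C_out, K, C_in) weights); the path and names are illustrative:

    import mlx.core as mx

    w = mx.random.normal(shape=(8, 3, 3))  # (C_out, K, C_in)

    def f(x):  # x has shape (N, L, C_in)
        return mx.conv1d(x, w, stride=2, padding=1)

    mx.export_function("conv.mlxfn", f, mx.random.normal(shape=(4, 64, 3)), shapeless=True)
    g = mx.import_function("conv.mlxfn")
    y = g(mx.random.normal(shape=(1, 33, 3)))[0]
    print(y.shape)  # (1, 17, 8): ((33 + 2 - 3) // 2) + 1 == 17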