From 7df3f792a2370dd021a7fdfad887add8883eb3bd Mon Sep 17 00:00:00 2001 From: Franck Verrot Date: Mon, 10 Feb 2025 06:27:01 -0800 Subject: [PATCH] Ensure Conv2D and Conv3D's kernel sizes aren't trimmed (#1852) Before the change, this snippet: ``` print(nn.Conv1d(1, 32, 3, padding=1)) print(nn.Conv2d(1, 32, (3, 3), padding=1)) print(nn.Conv3d(1, 32, (3, 3, 3), padding=1)) ``` would output: ``` Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True) Conv2d(1, 32, kernel_size=(3,), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True) Conv3d(1, 32, kernel_size=(3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True) ``` After the change, the output will be: ``` Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True) Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True) Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True) ``` --- python/mlx/nn/layers/convolution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py index f825307c0..88b97add0 100644 --- a/python/mlx/nn/layers/convolution.py +++ b/python/mlx/nn/layers/convolution.py @@ -147,7 +147,7 @@ class Conv2d(Module): def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " - f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " + f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " f"groups={self.groups}, " f"bias={'bias' in self}" @@ -220,7 +220,7 @@ class Conv3d(Module): def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " - f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " + f"kernel_size={self.weight.shape[1:4]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " f"bias={'bias' in self}" )