diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py index f825307c0..88b97add0 100644 --- a/python/mlx/nn/layers/convolution.py +++ b/python/mlx/nn/layers/convolution.py @@ -147,7 +147,7 @@ class Conv2d(Module): def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " - f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " + f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " f"groups={self.groups}, " f"bias={'bias' in self}" @@ -220,7 +220,7 @@ class Conv3d(Module): def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " - f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " + f"kernel_size={self.weight.shape[1:4]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " f"bias={'bias' in self}" )