diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index f24aab52a2..cd972502aa 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1402,10 +1402,16 @@ array isnan(const array& a, StreamOrDevice s /* = {} */) {
 }
 
 array isinf(const array& a, StreamOrDevice s /* = {} */) {
+  if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
+    return full(a.shape(), false, bool_, s);
+  }
   return logical_or(isposinf(a, s), isneginf(a, s), s);
 }
 
 array isfinite(const array& a, StreamOrDevice s /* = {} */) {
+  if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
+    return full(a.shape(), true, bool_, s);
+  }
   return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);
 }
 
@@ -1497,10 +1503,17 @@ array isclose(
   auto out = less_equal(lhs, rhs, s);
 
   // Correct the result for infinite values.
-  auto any_inf = logical_or(isinf(a, s), isinf(b, s), s);
+  auto a_pos_inf = isposinf(a, s);
+  auto b_pos_inf = isposinf(b, s);
+  auto a_neg_inf = isneginf(a, s);
+  auto b_neg_inf = isneginf(b, s);
+  auto any_inf = logical_or(
+      logical_or(a_pos_inf, a_neg_inf, s),
+      logical_or(b_pos_inf, b_neg_inf, s),
+      s);
   auto both_inf = logical_or(
-      logical_and(isposinf(a, s), isposinf(b, s), s),
-      logical_and(isneginf(a, s), isneginf(b, s), s),
+      logical_and(a_pos_inf, b_pos_inf, s),
+      logical_and(a_neg_inf, b_neg_inf, s),
       s);
 
   // Convert all elements where either value is infinite to False.
diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py
index d64ddc2451..f29341044f 100644
--- a/python/mlx/nn/layers/convolution.py
+++ b/python/mlx/nn/layers/convolution.py
@@ -101,6 +101,8 @@ class Conv2d(Module):
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        groups (int, optional): The number of groups for the convolution.
+            Default: ``1``.
         bias (bool, optional): If ``True`` add a learnable bias to the output.
             Default: ``True``
     """
@@ -113,10 +115,17 @@ class Conv2d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        groups: int = 1,
         bias: bool = True,
     ):
         super().__init__()
 
+        if in_channels % groups != 0:
+            raise ValueError(
+                f"The number of input channels ({in_channels}) must be "
+                f"divisible by the number of groups ({groups})"
+            )
+
         kernel_size, stride, padding = map(
             lambda x: (x, x) if isinstance(x, int) else x,
             (kernel_size, stride, padding),
@@ -125,7 +134,7 @@ class Conv2d(Module):
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(out_channels, *kernel_size, in_channels),
+            shape=(out_channels, *kernel_size, in_channels // groups),
         )
         if bias:
             self.bias = mx.zeros((out_channels,))
@@ -133,17 +142,21 @@ class Conv2d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.groups = groups
 
     def _extra_repr(self):
        return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"groups={self.groups}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
-        y = mx.conv2d(x, self.weight, self.stride, self.padding, self.dilation)
+        y = mx.conv2d(
+            x, self.weight, self.stride, self.padding, self.dilation, self.groups
+        )
         if "bias" in self:
             y = y + self.bias
         return y
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 4c891545f9..ad4c208ddf 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -706,6 +706,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (4, 4, 4, 8))
         self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
 
+        # 3x3 conv groups > 1
+        x = mx.ones((4, 7, 7, 4))
+        c = nn.Conv2d(4, 8, 3, padding=1, stride=1, groups=2)
+        y = c(x)
+        self.assertEqual(y.shape, (4, 7, 7, 8))
+
     def test_sequential(self):
         x = mx.ones((10, 2))
         m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
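
A quick sketch of the dtype short-circuit this patch adds to `isinf`/`isfinite`, as seen from the Python bindings (`mx.isinf`/`mx.isfinite`); the printed values follow from the patch, the exact repr formatting is illustrative:

```python
import mlx.core as mx

# Integer and boolean arrays can never hold inf or nan, so the patched
# isinf/isfinite return constant masks up front instead of evaluating
# isposinf/isneginf (or isnan) on non-float data.
i = mx.array([0, 1, 2])  # int32 by default
print(mx.isinf(i))       # all False
print(mx.isfinite(i))    # all True

# Floating-point inputs still take the original element-wise path.
f = mx.array([1.0, float("inf"), float("nan")])
print(mx.isinf(f))       # [False, True, False]
print(mx.isfinite(f))    # [True, False, False]
```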
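And a usage sketch of the new `groups` argument on `nn.Conv2d`, mirroring the shapes exercised by the added test (variable names are illustrative):

```python
import mlx.core as mx
import mlx.nn as nn

# groups=2 splits the 4 input channels into two groups of 2, so each of
# the 8 filters spans in_channels // groups = 2 channels; that is why
# the weight's trailing dimension shrinks from 4 to 2.
conv = nn.Conv2d(4, 8, kernel_size=3, padding=1, groups=2)
print(conv.weight.shape)   # (8, 3, 3, 2)

x = mx.ones((1, 7, 7, 4))  # NHWC input, as in the added test
print(conv(x).shape)       # (1, 7, 7, 8)

# An in_channels not divisible by groups now fails fast:
# nn.Conv2d(3, 8, kernel_size=3, groups=2)  raises ValueError
```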