Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-31 16:21:27 +08:00.
			
		
		
		
	add groups in conv2d (#1569)
This commit is contained in:
		
							
								
								
									
										19
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							| @@ -1402,10 +1402,16 @@ array isnan(const array& a, StreamOrDevice s /* = {} */) { | ||||
| } | ||||
|  | ||||
| // Elementwise test for infinity: returns a bool array that is true where | ||||
| // `a` is +inf or -inf. | ||||
| array isinf(const array& a, StreamOrDevice s /* = {} */) { | ||||
|   // Integral and bool dtypes cannot represent an infinity, so the result | ||||
|   // is all-false without inspecting the data. | ||||
|   if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { | ||||
|     return full(a.shape(), false, bool_, s); | ||||
|   } | ||||
|   // Infinite in either direction: positive OR negative infinity. | ||||
|   return logical_or(isposinf(a, s), isneginf(a, s), s); | ||||
| } | ||||
|  | ||||
| // Elementwise test for finiteness: returns a bool array that is true where | ||||
| // `a` is neither infinite nor NaN. | ||||
| array isfinite(const array& a, StreamOrDevice s /* = {} */) { | ||||
|   // Integral and bool dtypes are always finite, so the result is | ||||
|   // all-true without inspecting the data. | ||||
|   if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { | ||||
|     return full(a.shape(), true, bool_, s); | ||||
|   } | ||||
|   // finite == NOT (inf OR nan). | ||||
|   return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s); | ||||
| } | ||||
|  | ||||
| @@ -1497,10 +1503,17 @@ array isclose( | ||||
|   auto out = less_equal(lhs, rhs, s); | ||||
|  | ||||
|   // Correct the result for infinite values. | ||||
|   auto any_inf = logical_or(isinf(a, s), isinf(b, s), s); | ||||
|   auto a_pos_inf = isposinf(a, s); | ||||
|   auto b_pos_inf = isposinf(b, s); | ||||
|   auto a_neg_inf = isneginf(a, s); | ||||
|   auto b_neg_inf = isneginf(b, s); | ||||
|   auto any_inf = logical_or( | ||||
|       logical_or(a_pos_inf, a_neg_inf, s), | ||||
|       logical_or(b_pos_inf, b_neg_inf, s), | ||||
|       s); | ||||
|   auto both_inf = logical_or( | ||||
|       logical_and(isposinf(a, s), isposinf(b, s), s), | ||||
|       logical_and(isneginf(a, s), isneginf(b, s), s), | ||||
|       logical_and(a_pos_inf, b_pos_inf, s), | ||||
|       logical_and(a_neg_inf, b_neg_inf, s), | ||||
|       s); | ||||
|  | ||||
|   // Convert all elements where either value is infinite to False. | ||||
|   | ||||
| @@ -101,6 +101,8 @@ class Conv2d(Module): | ||||
|         padding (int or tuple, optional): How many positions to 0-pad | ||||
|             the input with. Default: ``0``. | ||||
|         dilation (int or tuple, optional): The dilation of the convolution. | ||||
|         groups (int, optional): The number of groups for the convolution. | ||||
|             Default: ``1``. | ||||
|         bias (bool, optional): If ``True`` add a learnable bias to the | ||||
|             output. Default: ``True`` | ||||
|     """ | ||||
| @@ -113,10 +115,17 @@ class Conv2d(Module): | ||||
|         stride: Union[int, tuple] = 1, | ||||
|         padding: Union[int, tuple] = 0, | ||||
|         dilation: Union[int, tuple] = 1, | ||||
|         groups: int = 1, | ||||
|         bias: bool = True, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         if in_channels % groups != 0: | ||||
|             raise ValueError( | ||||
|                 f"The number of input channels ({in_channels}) must be " | ||||
|                 f"divisible by the number of groups ({groups})" | ||||
|             ) | ||||
|  | ||||
|         kernel_size, stride, padding = map( | ||||
|             lambda x: (x, x) if isinstance(x, int) else x, | ||||
|             (kernel_size, stride, padding), | ||||
| @@ -125,7 +134,7 @@ class Conv2d(Module): | ||||
|         self.weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(out_channels, *kernel_size, in_channels), | ||||
|             shape=(out_channels, *kernel_size, in_channels // groups), | ||||
|         ) | ||||
|         if bias: | ||||
|             self.bias = mx.zeros((out_channels,)) | ||||
| @@ -133,17 +142,21 @@ class Conv2d(Module): | ||||
|         self.padding = padding | ||||
|         self.stride = stride | ||||
|         self.dilation = dilation | ||||
|         self.groups = groups | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         # Summarize the layer configuration for repr(). Channel counts are | ||||
|         # read back from the weight, whose shape is | ||||
|         # (out_channels, *kernel_size, in_channels // groups) — so with | ||||
|         # groups > 1 the first field shows in_channels // groups, not | ||||
|         # in_channels. NOTE(review): shape[1:2] prints only the first | ||||
|         # kernel dimension; looks like it should be shape[1:3] — confirm. | ||||
|         return ( | ||||
|             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " | ||||
|             f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " | ||||
|             f"padding={self.padding}, dilation={self.dilation}, " | ||||
|             f"groups={self.groups}, " | ||||
|             f"bias={'bias' in self}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         y = mx.conv2d(x, self.weight, self.stride, self.padding, self.dilation) | ||||
|         y = mx.conv2d( | ||||
|             x, self.weight, self.stride, self.padding, self.dilation, self.groups | ||||
|         ) | ||||
|         if "bias" in self: | ||||
|             y = y + self.bias | ||||
|         return y | ||||
|   | ||||
| @@ -706,6 +706,12 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(y.shape, (4, 4, 4, 8)) | ||||
|         self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4) | ||||
|  | ||||
|         # 3x3 conv groups > 1 | ||||
|         x = mx.ones((4, 7, 7, 4)) | ||||
|         c = nn.Conv2d(4, 8, 3, padding=1, stride=1, groups=2) | ||||
|         y = c(x) | ||||
|         self.assertEqual(y.shape, (4, 7, 7, 8)) | ||||
|  | ||||
|     def test_sequential(self): | ||||
|         x = mx.ones((10, 2)) | ||||
|         m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun