mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		 Gabrijel Boduljak
					Gabrijel Boduljak
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							40c108766b
						
					
				
				
					commit
					e54cbb7ba6
				
			| @@ -905,6 +905,347 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.float16) | ||||
|  | ||||
|     def test_pooling(self): | ||||
|         # Test 1d pooling | ||||
|         x = mx.array( | ||||
|             [ | ||||
|                 [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], | ||||
|                 [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]], | ||||
|             ] | ||||
|         ) | ||||
|         expected_max_pool_output_no_padding_stride_1 = [ | ||||
|             [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], | ||||
|             [[15.0, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], | ||||
|         ] | ||||
|         expected_max_pool_output_no_padding_stride_2 = [ | ||||
|             [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]], | ||||
|             [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]], | ||||
|         ] | ||||
|         expected_max_pool_output_padding_1_stride_2 = [ | ||||
|             [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], | ||||
|             [[12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], | ||||
|         ] | ||||
|         expected_max_pool_output_padding_1_stride_2_kernel_3 = [ | ||||
|             [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]], | ||||
|             [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]], | ||||
|         ] | ||||
|         expected_avg_pool_output_no_padding_stride_1 = [ | ||||
|             [ | ||||
|                 [1.5000, 2.5000, 3.5000], | ||||
|                 [4.5000, 5.5000, 6.5000], | ||||
|                 [7.5000, 8.5000, 9.5000], | ||||
|             ], | ||||
|             [ | ||||
|                 [13.5000, 14.5000, 15.5000], | ||||
|                 [16.5000, 17.5000, 18.5000], | ||||
|                 [19.5000, 20.5000, 21.5000], | ||||
|             ], | ||||
|         ] | ||||
|         expected_avg_pool_output_no_padding_stride_2 = [ | ||||
|             [[1.5000, 2.5000, 3.5000], [7.5000, 8.5000, 9.5000]], | ||||
|             [[13.5000, 14.5000, 15.5000], [19.5000, 20.5000, 21.5000]], | ||||
|         ] | ||||
|         expected_avg_pool_output_padding_1_stride_2 = [ | ||||
|             [ | ||||
|                 [0.0000, 0.5000, 1.0000], | ||||
|                 [4.5000, 5.5000, 6.5000], | ||||
|                 [4.5000, 5.0000, 5.5000], | ||||
|             ], | ||||
|             [ | ||||
|                 [6.0000, 6.5000, 7.0000], | ||||
|                 [16.5000, 17.5000, 18.5000], | ||||
|                 [10.5000, 11.0000, 11.5000], | ||||
|             ], | ||||
|         ] | ||||
|         expected_avg_pool_output_padding_1_kernel_3 = [ | ||||
|             [[1, 1.66667, 2.33333], [6, 7, 8]], | ||||
|             [[9, 9.66667, 10.3333], [18, 19, 20]], | ||||
|         ] | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool1d(kernel_size=2, stride=1, padding=0)(x), | ||||
|                 expected_max_pool_output_no_padding_stride_1, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool1d(kernel_size=2, stride=2, padding=0)(x), | ||||
|                 expected_max_pool_output_no_padding_stride_2, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool1d(kernel_size=2, stride=2, padding=1)(x), | ||||
|                 expected_max_pool_output_padding_1_stride_2, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool1d(kernel_size=3, stride=2, padding=1)(x), | ||||
|                 expected_max_pool_output_padding_1_stride_2_kernel_3, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 nn.AvgPool1d(kernel_size=2, stride=1, padding=0)(x), | ||||
|                 expected_avg_pool_output_no_padding_stride_1, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 nn.AvgPool1d(kernel_size=2, stride=2, padding=0)(x), | ||||
|                 expected_avg_pool_output_no_padding_stride_2, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 nn.AvgPool1d(kernel_size=2, stride=2, padding=1)(x), | ||||
|                 expected_avg_pool_output_padding_1_stride_2, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 nn.AvgPool1d(kernel_size=3, stride=2, padding=1)(x), | ||||
|                 expected_avg_pool_output_padding_1_kernel_3, | ||||
|             ) | ||||
|         ) | ||||
|         # Test 2d pooling | ||||
|         x = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [[0, 16], [1, 17], [2, 18], [3, 19]], | ||||
|                     [[4, 20], [5, 21], [6, 22], [7, 23]], | ||||
|                     [[8, 24], [9, 25], [10, 26], [11, 27]], | ||||
|                     [[12, 28], [13, 29], [14, 30], [15, 31]], | ||||
|                 ] | ||||
|             ] | ||||
|         ) | ||||
|         expected_max_pool_output_no_padding_stride_1 = [ | ||||
|             [ | ||||
|                 [[5, 21], [6, 22], [7, 23]], | ||||
|                 [[9, 25], [10, 26], [11, 27]], | ||||
|                 [[13, 29], [14, 30], [15, 31]], | ||||
|             ] | ||||
|         ] | ||||
|         expected_max_pool_output_no_padding_stride_2 = [ | ||||
|             [[[5, 21], [7, 23]], [[13, 29], [15, 31]]] | ||||
|         ] | ||||
|         expected_max_pool_output_padding_1 = [ | ||||
|             [ | ||||
|                 [[0, 16], [2, 18], [3, 19]], | ||||
|                 [[8, 24], [10, 26], [11, 27]], | ||||
|                 [[12, 28], [14, 30], [15, 31]], | ||||
|             ] | ||||
|         ] | ||||
|         expected_mean_pool_output_no_padding_stride_1 = [ | ||||
|             [ | ||||
|                 [[2.5000, 18.5000], [3.5000, 19.5000], [4.5000, 20.5000]], | ||||
|                 [[6.5000, 22.5000], [7.5000, 23.5000], [8.5000, 24.5000]], | ||||
|                 [[10.5000, 26.5000], [11.5000, 27.5000], [12.5000, 28.5000]], | ||||
|             ] | ||||
|         ] | ||||
|         expected_mean_pool_output_no_padding_stride_2 = [ | ||||
|             [ | ||||
|                 [[2.5000, 18.5000], [4.5000, 20.5000]], | ||||
|                 [[10.5000, 26.5000], [12.5000, 28.5000]], | ||||
|             ] | ||||
|         ] | ||||
|         expected_mean_pool_output_padding_1 = [ | ||||
|             [ | ||||
|                 [[0.0000, 4.0000], [0.7500, 8.7500], [0.7500, 4.7500]], | ||||
|                 [[3.0000, 11.0000], [7.5000, 23.5000], [4.5000, 12.5000]], | ||||
|                 [[3.0000, 7.0000], [6.7500, 14.7500], [3.7500, 7.7500]], | ||||
|             ] | ||||
|         ] | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool2d(kernel_size=2, stride=1, padding=0)(x), | ||||
|                 expected_max_pool_output_no_padding_stride_1, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool2d(kernel_size=2, stride=2, padding=0)(x), | ||||
|                 expected_max_pool_output_no_padding_stride_2, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(x), | ||||
|                 expected_max_pool_output_padding_1, | ||||
|             ) | ||||
|         ) | ||||
|         # Average pooling | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 nn.AvgPool2d(kernel_size=2, stride=1, padding=0)(x), | ||||
|                 expected_mean_pool_output_no_padding_stride_1, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.AvgPool2d(kernel_size=2, stride=2, padding=0)(x), | ||||
|                 expected_mean_pool_output_no_padding_stride_2, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.AvgPool2d(kernel_size=2, stride=2, padding=1)(x), | ||||
|                 expected_mean_pool_output_padding_1, | ||||
|             ) | ||||
|         ) | ||||
|         # Test multiple batches | ||||
|         x = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [[0, 1], [2, 3], [4, 5], [6, 7]], | ||||
|                     [[8, 9], [10, 11], [12, 13], [14, 15]], | ||||
|                     [[16, 17], [18, 19], [20, 21], [22, 23]], | ||||
|                     [[24, 25], [26, 27], [28, 29], [30, 31]], | ||||
|                 ], | ||||
|                 [ | ||||
|                     [[32, 33], [34, 35], [36, 37], [38, 39]], | ||||
|                     [[40, 41], [42, 43], [44, 45], [46, 47]], | ||||
|                     [[48, 49], [50, 51], [52, 53], [54, 55]], | ||||
|                     [[56, 57], [58, 59], [60, 61], [62, 63]], | ||||
|                 ], | ||||
|             ] | ||||
|         ) | ||||
|         expected_max_pool_output = [ | ||||
|             [[[10.0, 11.0], [14.0, 15.0]], [[26.0, 27.0], [30.0, 31.0]]], | ||||
|             [[[42.0, 43.0], [46.0, 47.0]], [[58.0, 59.0], [62.0, 63.0]]], | ||||
|         ] | ||||
|         expected_avg_pool_output = [ | ||||
|             [[[2.22222, 2.66667], [5.33333, 6]], [[11.3333, 12], [20, 21]]], | ||||
|             [[[16.4444, 16.8889], [26.6667, 27.3333]], [[32.6667, 33.3333], [52, 53]]], | ||||
|         ] | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x), | ||||
|                 expected_max_pool_output, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 nn.AvgPool2d(kernel_size=3, stride=2, padding=1)(x), | ||||
|                 expected_avg_pool_output, | ||||
|             ) | ||||
|         ) | ||||
|         # Test irregular kernel (2, 4), stride (3, 1) and padding (1, 2) | ||||
|         x = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], | ||||
|                     [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]], | ||||
|                     [[24, 25, 26], [27, 28, 29], [30, 31, 32], [33, 34, 35]], | ||||
|                     [[36, 37, 38], [39, 40, 41], [42, 43, 44], [45, 46, 47]], | ||||
|                 ], | ||||
|                 [ | ||||
|                     [[48, 49, 50], [51, 52, 53], [54, 55, 56], [57, 58, 59]], | ||||
|                     [[60, 61, 62], [63, 64, 65], [66, 67, 68], [69, 70, 71]], | ||||
|                     [[72, 73, 74], [75, 76, 77], [78, 79, 80], [81, 82, 83]], | ||||
|                     [[84, 85, 86], [87, 88, 89], [90, 91, 92], [93, 94, 95]], | ||||
|                 ], | ||||
|             ] | ||||
|         ) | ||||
|         expected_irregular_max_pool_output = [ | ||||
|             [ | ||||
|                 [ | ||||
|                     [3.0, 4.0, 5.0], | ||||
|                     [6.0, 7.0, 8.0], | ||||
|                     [9.0, 10.0, 11.0], | ||||
|                     [9.0, 10.0, 11.0], | ||||
|                     [9.0, 10.0, 11.0], | ||||
|                 ], | ||||
|                 [ | ||||
|                     [39.0, 40.0, 41.0], | ||||
|                     [42.0, 43.0, 44.0], | ||||
|                     [45.0, 46.0, 47.0], | ||||
|                     [45.0, 46.0, 47.0], | ||||
|                     [45.0, 46.0, 47.0], | ||||
|                 ], | ||||
|             ], | ||||
|             [ | ||||
|                 [ | ||||
|                     [51.0, 52.0, 53.0], | ||||
|                     [54.0, 55.0, 56.0], | ||||
|                     [57.0, 58.0, 59.0], | ||||
|                     [57.0, 58.0, 59.0], | ||||
|                     [57.0, 58.0, 59.0], | ||||
|                 ], | ||||
|                 [ | ||||
|                     [87.0, 88.0, 89.0], | ||||
|                     [90.0, 91.0, 92.0], | ||||
|                     [93.0, 94.0, 95.0], | ||||
|                     [93.0, 94.0, 95.0], | ||||
|                     [93.0, 94.0, 95.0], | ||||
|                 ], | ||||
|             ], | ||||
|         ] | ||||
|         expected_irregular_average_pool_output = [ | ||||
|             [ | ||||
|                 [ | ||||
|                     [0.3750, 0.6250, 0.8750], | ||||
|                     [1.1250, 1.5000, 1.8750], | ||||
|                     [2.2500, 2.7500, 3.2500], | ||||
|                     [2.2500, 2.6250, 3.0000], | ||||
|                     [1.8750, 2.1250, 2.3750], | ||||
|                 ], | ||||
|                 [ | ||||
|                     [15.7500, 16.2500, 16.7500], | ||||
|                     [24.7500, 25.5000, 26.2500], | ||||
|                     [34.5000, 35.5000, 36.5000], | ||||
|                     [27.0000, 27.7500, 28.5000], | ||||
|                     [18.7500, 19.2500, 19.7500], | ||||
|                 ], | ||||
|             ], | ||||
|             [ | ||||
|                 [ | ||||
|                     [12.3750, 12.6250, 12.8750], | ||||
|                     [19.1250, 19.5000, 19.8750], | ||||
|                     [26.2500, 26.7500, 27.2500], | ||||
|                     [20.2500, 20.6250, 21.0000], | ||||
|                     [13.8750, 14.1250, 14.3750], | ||||
|                 ], | ||||
|                 [ | ||||
|                     [39.7500, 40.2500, 40.7500], | ||||
|                     [60.7500, 61.5000, 62.2500], | ||||
|                     [82.5000, 83.5000, 84.5000], | ||||
|                     [63.0000, 63.7500, 64.5000], | ||||
|                     [42.7500, 43.2500, 43.7500], | ||||
|                 ], | ||||
|             ], | ||||
|         ] | ||||
|         self.assertTrue( | ||||
|             np.array_equal( | ||||
|                 nn.MaxPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x), | ||||
|                 expected_irregular_max_pool_output, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 nn.AvgPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x), | ||||
|                 expected_irregular_average_pool_output, | ||||
|             ) | ||||
|         ) | ||||
|         # Test repr | ||||
|         self.assertEqual( | ||||
|             str(nn.MaxPool1d(kernel_size=3, padding=2)), | ||||
|             "MaxPool1d(kernel_size=(3,), stride=(3,), padding=(2,))", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             str(nn.AvgPool1d(kernel_size=2, stride=3)), | ||||
|             "AvgPool1d(kernel_size=(2,), stride=(3,), padding=(0,))", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             str(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), | ||||
|             "MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))), | ||||
|             "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))", | ||||
|         ) | ||||
|  | ||||
|  | ||||
| # Run the tests via unittest when this file is executed directly as a script. | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user