mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	fix: conv_general differences between gpu, cpu (#2070)
* fix general_conv padding * fix bugs * add test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase): | ||||
|                         atol=2e-5 if dtype == np.float32 else 5e-4, | ||||
|                     ) | ||||
|  | ||||
|     @unittest.skipIf(not has_torch, "requires Torch") | ||||
|     def test_asymmetric_padding(self): | ||||
|         inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32) | ||||
|         kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32) | ||||
|         strides = (2, 2, 2) | ||||
|  | ||||
|         pt_out = torch.conv3d( | ||||
|             torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)), | ||||
|             torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)), | ||||
|             stride=strides, | ||||
|             padding=2, | ||||
|         ) | ||||
|         pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy() | ||||
|  | ||||
|         mx_out = mx.conv_general( | ||||
|             mx.array(inputs), | ||||
|             mx.array(kernel), | ||||
|             stride=strides, | ||||
|             padding=([0, 0, 0], [1, 1, 1]), | ||||
|         ) | ||||
|  | ||||
|         self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) | ||||
|  | ||||
|         inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32) | ||||
|         kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32) | ||||
|  | ||||
|         pt_out = torch.conv2d( | ||||
|             torch.permute(torch.tensor(inputs), (0, 3, 1, 2)), | ||||
|             torch.permute(torch.tensor(kernel), (0, 3, 1, 2)), | ||||
|             stride=1, | ||||
|             padding=(1, 0), | ||||
|         ) | ||||
|         pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy() | ||||
|  | ||||
|         mx_out = mx.conv_general( | ||||
|             mx.array(inputs), | ||||
|             mx.array(kernel), | ||||
|             stride=1, | ||||
|             padding=([0, 0], [1, 0]), | ||||
|         ) | ||||
|         self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 ATurker
					ATurker