add bfloat conv for Winograd (#1306)
* add bfloat conv for Winograd
* accumulate in fp32
* accumulate in bf16
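For context, the Winograd algorithm the title refers to computes small convolution tiles with fewer multiplications by working in a transform domain. Below is a minimal numpy sketch of the 1-D case F(2, 3) — two outputs of a 3-tap filter from four multiplies, using the standard Lavin & Gray transform matrices. This is an illustration only: the commit itself implements the tiled 2-D version in Metal kernels, which is not shown here, and all names below are illustrative.

    import numpy as np

    # Winograd F(2, 3): 2 outputs of a 3-tap correlation with 4 multiplies
    # instead of 6. Transform matrices from Lavin & Gray (2015).
    B_T = np.array([[1, 0, -1, 0],
                    [0, 1, 1, 0],
                    [0, -1, 1, 0],
                    [0, 1, 0, -1]], dtype=np.float32)
    G = np.array([[1, 0, 0],
                  [0.5, 0.5, 0.5],
                  [0.5, -0.5, 0.5],
                  [0, 0, 1]], dtype=np.float32)
    A_T = np.array([[1, 1, 1, 0],
                    [0, 1, -1, -1]], dtype=np.float32)

    d = np.random.randn(4).astype(np.float32)  # input tile
    g = np.random.randn(3).astype(np.float32)  # filter taps

    # Elementwise product in the transform domain. With bfloat16 inputs this
    # multiply-and-sum is the step whose precision the commit message is
    # about ("accumulate in fp32" vs "accumulate in bf16").
    m = (G @ g) * (B_T @ d)
    y = A_T @ m

    # Reference: direct correlation over the two output positions.
    y_ref = np.array([d[0:3] @ g, d[1:4] @ g])
    assert np.allclose(y, y_ref, atol=1e-5)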
@@ -275,7 +275,6 @@ class TestConv(mlx_tests.MLXTestCase):
             dilation=(1, 1),
             groups=1,
             dtype="float32",
-            atol=1e-5,
         ):
             with self.subTest(
                 dtype=dtype,
@@ -289,19 +288,22 @@ class TestConv(mlx_tests.MLXTestCase):
                 dilation=dilation,
                 groups=groups,
             ):
-                np_dtype = getattr(np, dtype)
                 np.random.seed(0)
                 iH, iW = idim
                 kH, kW = kdim
                 scale = 1.0 / math.sqrt(kH * kW * C)
-                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
-                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
-                    np_dtype
-                )
+                in_np = np.random.normal(0.0, scale, (N, iH, iW, C))
+                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups)))

-                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                mx_dtype = getattr(mx, dtype)
+                torch_dtype = getattr(torch, dtype)
+                in_mx, wt_mx = map(
+                    lambda x: mx.array(x).astype(mx_dtype), (in_np, wt_np)
+                )
                 in_pt, wt_pt = map(
-                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
+                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2))
+                    .to("cpu")
+                    .to(torch_dtype),
                     (in_np, wt_np),
                 )

@@ -312,7 +314,7 @@ class TestConv(mlx_tests.MLXTestCase):
                     padding=padding,
                     dilation=dilation,
                     groups=groups,
-                )
+                ).astype(mx.float32)
                 out_pt = torch.conv2d(
                     in_pt,
                     wt_pt,
@@ -321,12 +323,20 @@ class TestConv(mlx_tests.MLXTestCase):
                     dilation=dilation,
                     groups=groups,
                 )
-                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+                out_pt = (
+                    torch.permute(out_pt, (0, 2, 3, 1))
+                    .to(torch.float32)
+                    .numpy(force=True)
+                )

                 self.assertEqual(out_pt.shape, out_mx.shape)
+                if dtype == "bfloat16":
+                    atol, rtol = 1e-1, 1e-3
+                else:
+                    atol, rtol = 1e-5, 1e-6
                 self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

-        for dtype in ("float32",):
+        for dtype in ("float32", "bfloat16"):
             for N, C, O in (
                 (1, 1, 1),
                 (1, 6, 1),
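The much looser bfloat16 tolerance in the test above (atol=1e-1 versus 1e-5 for float32) follows from bfloat16's 8-bit significand, which gives only 2-3 significant decimal digits. The same limit is why the commit message distinguishes accumulating in fp32 from accumulating in bf16. Below is a minimal sketch of the effect in plain PyTorch (an illustration, not the commit's Metal kernels): a running sum rounded to bfloat16 after every add stalls once the accumulator dwarfs the addends, while an fp32 accumulator stays accurate.

    import torch

    # Illustration only: 2048 copies of 0.01 summed two ways.
    vals = torch.full((2048,), 0.01)  # float32

    acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
    for v in vals:
        # Round the running sum to bfloat16 after every add.
        acc_bf16 = acc_bf16 + v.to(torch.bfloat16)

    acc_fp32 = vals.sum()  # float32 accumulation

    print(acc_bf16.item())  # roughly 4.0: once the sum reaches 4, an
                            # increment of ~0.01 is below half a bf16 ulp
                            # (0.015625) and rounds away entirely
    print(acc_fp32.item())  # 20.48 (2048 * 0.01), as expected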
Author: Awni Hannun