mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			811 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			811 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023-2024 Apple Inc.
 | |
| 
 | |
| import math
 | |
| import unittest
 | |
| from itertools import permutations
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx_tests
 | |
| import numpy as np
 | |
| 
 | |
| try:
 | |
|     import torch
 | |
|     import torch.nn.functional as F
 | |
| 
 | |
|     has_torch = True
 | |
| except ImportError as e:
 | |
|     has_torch = False
 | |
| 
 | |
| 
 | |
class TestConvTranspose(mlx_tests.MLXTestCase):
    """Compare MLX transposed convolutions (1D/2D/3D) against PyTorch.

    Test data is generated channels-last for MLX, e.g. inputs ``(N, *spatial, C)``
    and weights ``(O, *kernel, C)``. It is transposed to PyTorch's
    channels-first layout for the reference call, and the reference result is
    transposed back before shapes and values are compared.
    """

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D(self):
        """Forward ``conv_transpose1d`` vs. torch, incl. groups and strided inputs."""

        def run_conv_transpose_1D(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            output_padding=0,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C // groups)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                # (N, iH, C) -> (N, C, iH); (O, kH, C) -> (C, O, kH) for torch.
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

                # output_padding is forwarded to both implementations; it was
                # previously accepted here but silently ignored.
                out_mx = mx.conv_transpose1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                    dilation=dilation,
                    groups=groups,
                )
                # Back to channels-last to match the MLX output.
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

        # Groups tests. Only groups=1 is exercised: the weight transpose above
        # matches torch's (in_channels, out_channels // groups, kH) layout only
        # when groups == 1.
        N, C, O = (4, 32, 64)
        for iH, kH, stride, padding in (
            (1, 1, 1, 0),
            (3, 3, 1, 0),
            (31, 5, 5, 2),
        ):
            for group in (1,):
                run_conv_transpose_1D(
                    N, C, O, iH, kH, stride, padding, groups=group, dtype="float32"
                )

        # Strided inputs tests: feed non-contiguous (transposed) views to MLX.
        for tpose_in, tpose_wt in (
            ((0, 2, 1), (0, 1, 2)),
            ((0, 2, 1), (0, 2, 1)),
        ):
            with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_mx_t = mx.transpose(in_mx, tpose_in)
                wt_mx_t = mx.transpose(wt_mx, tpose_wt)
                out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)

                in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))

                out_pt = torch.conv_transpose1d(in_pt, wt_pt)
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D_grad(self):
        """Check ``conv_transpose1d`` input/weight VJPs against torch autograd."""

        def run_conv_transpose1D_grad(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)

                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)

                # groups is forwarded so the reference uses the same settings
                # as the MLX call below (as the 3D grad test already does).
                out_pt = F.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # Backprop sum() so the output cotangent is all ones; keep it
                # so the same cotangent can be fed to mx.vjp.
                out_pt.retain_grad()
                out_pt.sum().backward()

                # Map torch gradients back to the MLX channels-last layouts.
                pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))

                def f(a, b):
                    return mx.conv_transpose1d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(f, [in_mx, wt_mx], [ct_mx])

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose1D_grad(
                        N, C, O, iH, kH, stride, padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D(self):
        """Forward ``conv_transpose2d`` vs. torch, incl. a groups=1 sweep."""

        def run_conv_transpose2D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C)
                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C // groups)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                # NHWC -> NCHW for the input; (O, kH, kW, C) -> (C, O, kH, kW).
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).to("cpu")

                out_mx = mx.conv_transpose2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
                ):
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

            # Groups tests (groups=1 only; the weight transpose above matches
            # torch's (in_channels, out_channels // groups, kH, kW) layout only
            # when groups == 1).
            N, C, O = (4, 32, 64)
            for idim, kdim, stride, padding in (
                ((1, 1), (1, 1), (1, 1), (0, 0)),
                ((3, 3), (3, 1), (1, 1), (0, 0)),
                ((31, 31), (5, 5), (5, 5), (2, 2)),
            ):
                for group in (1,):
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D_grad(self):
        """Check ``conv_transpose2d`` input/weight VJPs against torch autograd."""

        def run_conv_transpose2D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C * O)

                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
                    True
                )

                # groups is forwarded so the reference matches the MLX call.
                out_pt = F.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # Backprop sum() so the output cotangent is all ones; keep it
                # so the same cotangent can be fed to mx.vjp.
                out_pt.retain_grad()
                out_pt.sum().backward()

                # Map torch gradients back to the MLX channels-last layouts.
                pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))

                def f(a, b):
                    return mx.conv_transpose2d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                ):
                    # dilation is the 8th positional parameter.
                    run_conv_transpose2D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D(self):
        """Forward ``conv_transpose3d`` vs. torch."""

        def run_conv_transpose3D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)
                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                # NDHWC -> NCDHW; (O, kD, kH, kW, C) -> (C, O, kD, kH, kW).
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))

                out_mx = mx.conv_transpose3d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (2, 8, 16),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
                ):
                    run_conv_transpose3D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D_grad(self):
        """Check ``conv_transpose3d`` input/weight VJPs against torch autograd."""

        def run_conv_transpose3D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-4,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)

                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)).requires_grad_(
                    True
                )

                out_pt = F.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # Backprop sum() so the output cotangent is all ones; keep it
                # so the same cotangent can be fed to mx.vjp.
                out_pt.retain_grad()
                out_pt.sum().backward()

                # Map torch gradients back to the MLX channels-last layouts.
                pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))

                def f(a, b):
                    return mx.conv_transpose3d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
                    ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
                    ((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)),
                    ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
                ):
                    # dilation is the 8th positional parameter.
                    run_conv_transpose3D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_tranpose_1d_output_padding(self):
        """Forward ``conv_transpose1d`` with non-zero ``output_padding`` vs. torch."""

        def run_conv_transpose_1d_output_padding(
            N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

                out_mx = mx.conv_transpose1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )

                out_pt = torch.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        # Every case keeps output_padding < stride.
        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for iH, kH, stride, padding, output_padding in (
                    (3, 2, 2, 0, 1),
                    (5, 3, 2, 1, 0),
                    (7, 4, 3, 1, 2),
                ):
                    run_conv_transpose_1d_output_padding(
                        N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2d_output_padding(self):
        """Forward ``conv_transpose2d`` with non-zero ``output_padding`` vs. torch."""

        def run_conv_transpose_2d_output_padding(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            output_padding,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))

                out_mx = mx.conv_transpose2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )

                out_pt = torch.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for idim, kdim, stride, padding, output_padding in (
                    ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
                    ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
                    ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
                ):
                    run_conv_transpose_2d_output_padding(
                        N,
                        C,
                        O,
                        idim,
                        kdim,
                        stride,
                        padding,
                        output_padding,
                        dtype=dtype,
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3d_output_padding(self):
        """Forward ``conv_transpose3d`` with non-zero ``output_padding`` vs. torch."""

        def run_conv_transpose_3d_output_padding(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            output_padding,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))

                out_mx = mx.conv_transpose3d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for idim, kdim, stride, padding, output_padding in (
                    ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
                    ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
                    ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
                ):
                    run_conv_transpose_3d_output_padding(
                        N,
                        C,
                        O,
                        idim,
                        kdim,
                        stride,
                        padding,
                        output_padding,
                        dtype=dtype,
                    )
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Running this file directly executes its tests through the runner
    # provided by the project's mlx_tests helper module.
    mlx_tests.MLXTestRunner()
 | 
