# Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-04 10:38:10 +08:00)
# File info: 811 lines, 28 KiB, Python
# Copyright © 2023-2024 Apple Inc.
 | 
						|
 | 
						|
import math
 | 
						|
import unittest
 | 
						|
from itertools import permutations
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
import mlx_tests
 | 
						|
import numpy as np
 | 
						|
 | 
						|
# Torch is an optional dependency that serves as the reference implementation.
# ``has_torch`` records whether the import succeeded so that tests can be
# skipped (via ``unittest.skipIf``) when torch is not installed.
try:
    import torch
    import torch.nn.functional as F

    has_torch = True
except ImportError:
    has_torch = False
 | 
						|
 | 
						|
 | 
						|
class TestConvTranspose(mlx_tests.MLXTestCase):
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D(self):
        # Compares the forward result of mx.conv_transpose1d with
        # torch.conv_transpose1d over a grid of shapes, then with an explicit
        # groups argument, and finally with transposed (strided) operands.

        def run_conv_transpose_1D(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            output_padding=0,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Run a single MLX-vs-torch comparison inside a sub-test.

            The input is generated with shape ``(N, iH, C)`` and the weight
            with shape ``(O, kH, C // groups)``.

            Args:
                N, C, O, iH, kH: sizes used to build the input and weight.
                stride, padding, output_padding, dilation, groups: forwarded
                    unchanged to both the MLX and the torch call.
                dtype: name of the NumPy dtype used for the random data.
                atol: absolute tolerance passed to ``np.allclose``.
            """
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                # Re-seed per case so every sub-test sees the same data
                # regardless of which cases ran before it.
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C // groups)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                # torch receives the input as (N, C, iH) and the weight with
                # its last axis moved to the front.
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

                out_mx = mx.conv_transpose1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    output_padding=output_padding,
                    groups=groups,
                )
                out_pt = torch.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                    dilation=dilation,
                    groups=groups,
                )
                # Bring the torch result back to the (N, length, channels)
                # layout used by MLX before comparing.
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        # (iH, kH, stride, padding) combinations shared by the sweeps below.
        shape_configs = (
            (1, 1, 1, 0),
            (3, 3, 1, 0),
            (31, 5, 5, 2),
        )

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in shape_configs:
                    run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

            # Groups tests (kept inside the dtype loop so ``dtype`` is not a
            # leaked loop variable).
            N, C, O = (4, 32, 64)
            for iH, kH, stride, padding in shape_configs:
                for group in (1,):
                    run_conv_transpose_1D(
                        N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype
                    )

        # Strided inputs tests
        for tpose_in, tpose_wt in (
            ((0, 2, 1), (0, 1, 2)),
            ((0, 2, 1), (0, 2, 1)),
        ):
            with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                # Transposing on the MLX side yields non-contiguous operands.
                in_mx_t = mx.transpose(in_mx, tpose_in)
                wt_mx_t = mx.transpose(wt_mx, tpose_wt)
                out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)

                # Apply the same permutation in NumPy, then the usual layout
                # change for torch.
                in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))

                out_pt = torch.conv_transpose1d(in_pt, wt_pt)
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D_grad(self):
        # Compares the input and weight gradients of mx.conv_transpose1d
        # (obtained through mx.vjp) with torch autograd.

        def run_conv_transpose1D_grad(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Run a single gradient comparison inside a sub-test.

            The input is generated with shape ``(N, iH, C)`` and the weight
            with shape ``(O, kH, C)``.

            Args:
                N, C, O, iH, kH: sizes used to build the input and weight.
                stride, padding, dilation, groups: forwarded unchanged to
                    both the MLX and the torch call.
                dtype: name of the NumPy dtype used for the random data.
                atol: absolute tolerance passed to ``np.allclose``.
            """
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                # For reference (torch ConvTranspose1d, output_padding == 0):
                # oH = (iH - 1) * stride - 2 * padding + dilation * (kH - 1) + 1

                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)

                # groups is passed to torch as well so that both frameworks
                # always evaluate the same operation.
                out_pt = F.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # use torch to compute ct
                out_pt.retain_grad()
                out_pt.sum().backward()

                # Permute the torch gradients back to the MLX layouts.
                pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))

                def f(a, b):
                    return mx.conv_transpose1d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [
                        in_mx,
                        wt_mx,
                    ],
                    [
                        ct_mx,
                    ],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose1D_grad(
                        N, C, O, iH, kH, stride, padding, dtype=dtype
                    )
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D(self):
        # Forward check of mx.conv_transpose2d against torch.conv_transpose2d
        # over a grid of shapes, followed by a sweep with explicit groups.

        def run_conv_transpose2D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Compare one MLX result with the torch reference in a sub-test.

            ``idim`` and ``kdim`` are ``(height, width)`` pairs for the input
            and the kernel; the remaining convolution arguments are handed to
            both frameworks unchanged.
            """
            conv_args = dict(
                stride=stride, padding=padding, dilation=dilation, groups=groups
            )
            with self.subTest(
                dtype=dtype, N=N, C=C, O=O, idim=idim, kdim=kdim, **conv_args
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_h, in_w = idim
                k_h, k_w = kdim
                in_scale = 1.0 / math.sqrt(k_h * k_w * C)
                in_shape = (N, in_h, in_w, C)
                wt_shape = (O, k_h, k_w, int(C / groups))
                x_np = np.random.normal(0.0, in_scale, in_shape).astype(np_dtype)
                w_np = np.random.normal(0.0, 1.0, wt_shape).astype(np_dtype)

                x_mx = mx.array(x_np)
                w_mx = mx.array(w_np)
                # For torch the trailing axis of both operands is moved:
                # to position 1 for the input, to position 0 for the weight.
                x_pt = torch.from_numpy(x_np.transpose(0, 3, 1, 2)).to("cpu")
                w_pt = torch.from_numpy(w_np.transpose(3, 0, 1, 2)).to("cpu")

                got = mx.conv_transpose2d(x_mx, w_mx, **conv_args)
                ref = torch.conv_transpose2d(x_pt, w_pt, **conv_args)
                # Back to the MLX layout, as a NumPy array.
                ref = ref.permute(0, 2, 3, 1).numpy(force=True)

                self.assertEqual(ref.shape, got.shape)
                self.assertTrue(np.allclose(ref, got, atol=atol))

        # (idim, kdim, stride, padding) combinations used by both sweeps.
        spatial_cases = (
            ((1, 1), (1, 1), (1, 1), (0, 0)),
            ((3, 3), (3, 1), (1, 1), (0, 0)),
            ((31, 31), (5, 5), (5, 5), (2, 2)),
        )
        channel_cases = (
            (1, 1, 1),
            (1, 6, 1),
            (1, 1, 6),
            (4, 32, 64),
        )

        for dtype in ("float32",):
            for N, C, O in channel_cases:
                for idim, kdim, stride, padding in spatial_cases:
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

            # Groups tests
            N, C, O = (4, 32, 64)
            for idim, kdim, stride, padding in spatial_cases:
                for group in (1,):
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
                    )
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D_grad(self):
        # Compares the input and weight gradients of mx.conv_transpose2d
        # (obtained through mx.vjp) with torch autograd.

        def run_conv_transpose2D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Run a single gradient comparison inside a sub-test.

            The input is generated with shape ``(N, iH, iW, C)`` and the
            weight with shape ``(O, kH, kW, C)``, where ``idim = (iH, iW)``
            and ``kdim = (kH, kW)``.

            Args:
                N, C, O, idim, kdim: sizes used to build the input and weight.
                stride, padding, dilation, groups: forwarded unchanged to
                    both the MLX and the torch call.
                dtype: name of the NumPy dtype used for the random data.
                atol: absolute tolerance passed to ``np.allclose``.
            """
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C * O)

                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
                    True
                )

                # groups is passed to torch as well so that both frameworks
                # always evaluate the same operation.
                out_pt = F.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # use torch to compute ct
                out_pt.retain_grad()
                out_pt.sum().backward()

                # Permute the torch gradients back to the MLX layouts.
                pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))

                def f(a, b):
                    return mx.conv_transpose2d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                ):
                    run_conv_transpose2D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D(self):
        # Forward check of mx.conv_transpose3d against torch.conv_transpose3d
        # over a grid of channel and spatial configurations.

        def run_conv_transpose3D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Compare one MLX result with the torch reference in a sub-test.

            ``idim`` and ``kdim`` are 3-tuples giving the three spatial sizes
            of the input and the kernel; the remaining convolution arguments
            are handed to both frameworks unchanged.
            """
            conv_args = dict(
                stride=stride, padding=padding, dilation=dilation, groups=groups
            )
            with self.subTest(
                dtype=dtype, N=N, C=C, O=O, idim=idim, kdim=kdim, **conv_args
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_d, in_h, in_w = idim
                k_d, k_h, k_w = kdim
                in_scale = 1.0 / math.sqrt(k_d * k_h * k_w * C * O)
                in_shape = (N, in_d, in_h, in_w, C)
                wt_shape = (O, k_d, k_h, k_w, C)
                x_np = np.random.normal(0.0, in_scale, in_shape).astype(np_dtype)
                w_np = np.random.normal(0.0, 1.0, wt_shape).astype(np_dtype)

                x_mx = mx.array(x_np)
                w_mx = mx.array(w_np)
                # For torch the trailing axis of both operands is moved:
                # to position 1 for the input, to position 0 for the weight.
                x_pt = torch.from_numpy(x_np.transpose(0, 4, 1, 2, 3))
                w_pt = torch.from_numpy(w_np.transpose(4, 0, 1, 2, 3))

                got = mx.conv_transpose3d(x_mx, w_mx, **conv_args)
                ref = torch.conv_transpose3d(x_pt, w_pt, **conv_args)
                # Back to the MLX layout, as a NumPy array.
                ref = ref.permute(0, 2, 3, 4, 1).numpy(force=True)

                self.assertEqual(ref.shape, got.shape)
                self.assertTrue(np.allclose(ref, got, atol=atol))

        channel_cases = (
            (1, 1, 1),
            (1, 6, 1),
            (1, 1, 6),
            (2, 8, 16),
        )
        # (idim, kdim, stride, padding) combinations.
        spatial_cases = (
            ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
            ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
            ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
        )

        for dtype in ("float32",):
            for N, C, O in channel_cases:
                for idim, kdim, stride, padding in spatial_cases:
                    run_conv_transpose3D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D_grad(self):
        # Compares the input and weight gradients of mx.conv_transpose3d
        # (obtained through mx.vjp) with torch autograd.

        def run_conv_transpose3D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-4,
        ):
            """Run a single gradient comparison inside a sub-test.

            ``idim`` and ``kdim`` are 3-tuples giving the three spatial sizes
            of the input and the kernel; the remaining convolution arguments
            are handed to both frameworks unchanged.
            """
            conv_args = dict(
                stride=stride, padding=padding, dilation=dilation, groups=groups
            )
            with self.subTest(
                dtype=dtype, N=N, C=C, O=O, idim=idim, kdim=kdim, **conv_args
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_d, in_h, in_w = idim
                k_d, k_h, k_w = kdim
                sigma = 1.0 / math.sqrt(k_d * k_h * k_w * C * O)

                in_shape = (N, in_d, in_h, in_w, C)
                wt_shape = (O, k_d, k_h, k_w, C)
                x_np = np.random.normal(0.0, sigma, in_shape).astype(np_dtype)
                w_np = np.random.normal(0.0, sigma, wt_shape).astype(np_dtype)

                x_mx = mx.array(x_np)
                w_mx = mx.array(w_np)
                x_pt = torch.from_numpy(x_np.transpose(0, 4, 1, 2, 3))
                x_pt = x_pt.requires_grad_(True)
                w_pt = torch.from_numpy(w_np.transpose(4, 0, 1, 2, 3))
                w_pt = w_pt.requires_grad_(True)

                ref_out = F.conv_transpose3d(x_pt, w_pt, **conv_args)

                # use torch to compute ct
                ref_out.retain_grad()
                ref_out.sum().backward()

                # Permute the torch gradients back to the MLX layouts.
                ref_grad_in = x_pt.grad.permute(0, 2, 3, 4, 1).numpy()
                ref_grad_wt = w_pt.grad.permute(1, 2, 3, 4, 0).numpy()

                cotangent = mx.array(ref_out.grad.numpy().transpose(0, 2, 3, 4, 1))

                _, (grad_in, grad_wt) = mx.vjp(
                    lambda a, b: mx.conv_transpose3d(a, b, **conv_args),
                    [x_mx, w_mx],
                    [cotangent],
                )

                self.assertEqual(ref_grad_in.shape, grad_in.shape)
                self.assertEqual(x_mx.shape, grad_in.shape)
                self.assertTrue(np.allclose(ref_grad_in, grad_in, atol=atol))

                self.assertEqual(ref_grad_wt.shape, grad_wt.shape)
                self.assertEqual(w_mx.shape, grad_wt.shape)
                self.assertTrue(np.allclose(ref_grad_wt, grad_wt, atol=atol))

        channel_cases = ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16))
        # (idim, kdim, stride, padding, dilation) combinations.
        spatial_cases = (
            ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
            ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
            ((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
            ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
            ((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)),
            ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
        )

        for dtype in ("float32",):
            for N, C, O in channel_cases:
                for idim, kdim, stride, padding, dilation in spatial_cases:
                    run_conv_transpose3D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_tranpose_1d_output_padding(self):
        # Forward check of mx.conv_transpose1d against torch.conv_transpose1d
        # when an output_padding argument is supplied.

        def run_conv_transpose_1d_output_padding(
            N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
        ):
            """Compare one MLX result with the torch reference in a sub-test.

            The input is generated with shape ``(N, iH, C)`` and the weight
            with shape ``(O, kH, C)``; ``stride``, ``padding`` and
            ``output_padding`` are handed to both frameworks unchanged.
            """
            conv_args = dict(
                stride=stride, padding=padding, output_padding=output_padding
            )
            with self.subTest(dtype=dtype, N=N, C=C, O=O, iH=iH, kH=kH, **conv_args):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                sigma = 1.0 / C
                x_np = np.random.normal(0, sigma, (N, iH, C)).astype(np_dtype)
                w_np = np.random.normal(0, sigma, (O, kH, C)).astype(np_dtype)

                x_mx = mx.array(x_np)
                w_mx = mx.array(w_np)
                # For torch the trailing axis of both operands is moved:
                # to position 1 for the input, to position 0 for the weight.
                x_pt = torch.from_numpy(x_np.transpose(0, 2, 1))
                w_pt = torch.from_numpy(w_np.transpose(2, 0, 1))

                got = mx.conv_transpose1d(x_mx, w_mx, **conv_args)

                ref = torch.conv_transpose1d(x_pt, w_pt, **conv_args)
                # Swap the last two axes to return to the MLX layout.
                ref = ref.transpose(2, 1)

                self.assertEqual(ref.shape, got.shape)
                self.assertTrue(np.allclose(ref.numpy(), got, atol=atol))

        channel_cases = ((1, 1, 1), (1, 6, 1), (4, 32, 64))
        # (iH, kH, stride, padding, output_padding) combinations.
        length_cases = (
            (3, 2, 2, 0, 1),
            (5, 3, 2, 1, 0),
            (7, 4, 3, 1, 2),
        )

        for dtype in ("float32",):
            for N, C, O in channel_cases:
                for iH, kH, stride, padding, output_padding in length_cases:
                    run_conv_transpose_1d_output_padding(
                        N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
                    )
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
 | 
						|
    def test_torch_conv_transpose_2d_output_padding(self):
        """Compare ``mx.conv_transpose2d`` with ``torch.conv_transpose2d``
        when ``output_padding`` is supplied.

        Each case seeds NumPy, draws an input of shape ``(N, iH, iW, C)`` and
        a weight of shape ``(O, kH, kW, C)``, runs both libraries with the
        same stride / padding / output_padding, and checks that the results
        agree in shape and in value (within ``atol``).
        """

        def check_case(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            output_padding,
            dtype="float32",
            atol=1e-5,
        ):
            # Keyword order here is the order reported in the sub-test label.
            label = dict(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            )
            with self.subTest(**label):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_h, in_w = idim
                k_h, k_w = kdim
                sigma = 1.0 / C

                # Input is drawn before the weight so the random stream is
                # consumed in a fixed order.
                x_np = np.random.normal(0, sigma, (N, in_h, in_w, C))
                x_np = x_np.astype(np_dtype)
                w_np = np.random.normal(0, sigma, (O, k_h, k_w, C))
                w_np = w_np.astype(np_dtype)

                # Arguments shared verbatim by the MLX and PyTorch calls.
                conv_kwargs = dict(
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )

                out_mx = mx.conv_transpose2d(
                    mx.array(x_np), mx.array(w_np), **conv_kwargs
                )

                # Move the trailing channel axis to the front positions used
                # for the torch call: input -> (N, C, iH, iW),
                # weight -> (C, O, kH, kW).
                x_pt = torch.from_numpy(x_np.transpose(0, 3, 1, 2))
                w_pt = torch.from_numpy(w_np.transpose(3, 0, 1, 2))
                ref = torch.conv_transpose2d(x_pt, w_pt, **conv_kwargs)
                # Back to channels-last so it lines up with the MLX output.
                out_pt = torch.permute(ref, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        # (N, C, O) triples.
        channel_configs = ((1, 1, 1), (1, 6, 1), (4, 32, 64))
        # (idim, kdim, stride, padding, output_padding) per case.
        geometries = (
            ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
            ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
            ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
        )

        for dtype in ("float32",):
            for N, C, O in channel_configs:
                for geometry in geometries:
                    check_case(N, C, O, *geometry, dtype=dtype)
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
 | 
						|
    def test_torch_conv_transpose_3d_output_padding(self):
        """Compare ``mx.conv_transpose3d`` with ``torch.conv_transpose3d``
        when ``output_padding`` is supplied.

        Each case seeds NumPy, draws an input of shape ``(N, iD, iH, iW, C)``
        and a weight of shape ``(O, kD, kH, kW, C)``, runs both libraries with
        the same stride / padding / output_padding, and checks that the
        results agree in shape and in value (within ``atol``).
        """

        def check_case(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            output_padding,
            dtype="float32",
            atol=1e-5,
        ):
            # Keyword order here is the order reported in the sub-test label.
            label = dict(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            )
            with self.subTest(**label):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_d, in_h, in_w = idim
                k_d, k_h, k_w = kdim
                sigma = 1.0 / C

                # Input is drawn before the weight so the random stream is
                # consumed in a fixed order.
                x_np = np.random.normal(0, sigma, (N, in_d, in_h, in_w, C))
                x_np = x_np.astype(np_dtype)
                w_np = np.random.normal(0, sigma, (O, k_d, k_h, k_w, C))
                w_np = w_np.astype(np_dtype)

                # Arguments shared verbatim by the MLX and PyTorch calls.
                conv_kwargs = dict(
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )

                out_mx = mx.conv_transpose3d(
                    mx.array(x_np), mx.array(w_np), **conv_kwargs
                )

                # Move the trailing channel axis to the front positions used
                # for the torch call: input -> (N, C, iD, iH, iW),
                # weight -> (C, O, kD, kH, kW).
                x_pt = torch.from_numpy(x_np.transpose(0, 4, 1, 2, 3))
                w_pt = torch.from_numpy(w_np.transpose(4, 0, 1, 2, 3))
                ref = torch.conv_transpose3d(x_pt, w_pt, **conv_kwargs)
                # Back to channels-last so it lines up with the MLX output.
                out_pt = torch.permute(ref, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        # (N, C, O) triples.
        channel_configs = ((1, 1, 1), (1, 6, 1), (4, 32, 64))
        # (idim, kdim, stride, padding, output_padding) per case.
        geometries = (
            ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
            ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
            ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
        )

        for dtype in ("float32",):
            for N, C, O in channel_configs:
                for geometry in geometries:
                    check_case(N, C, O, *geometry, dtype=dtype)
 | 
						|
 | 
						|
 | 
						|
# Script entry point: when this file is executed directly, hand off to
# mlx_tests.MLXTestRunner (presumably the project's unittest launcher).
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
 |