Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
This commit is contained in:
Awni Hannun
2024-01-30 13:11:01 -08:00
committed by GitHub
parent d3a9005454
commit 09b9275027
13 changed files with 141 additions and 140 deletions

View File

@@ -44,7 +44,7 @@ class TestConv(mlx_tests.MLXTestCase):
c_np = np.convolve(a_np, v_np, mode=mode)
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
self.assertEqual(c_mx.shape, c_np.shape)
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
@unittest.skipIf(not has_torch, "requires Torch")
@@ -102,7 +102,7 @@ class TestConv(mlx_tests.MLXTestCase):
)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape)
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
for dtype in ("float32",):
@@ -141,7 +141,7 @@ class TestConv(mlx_tests.MLXTestCase):
out_pt = torch.conv1d(in_pt, wt_pt)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape)
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
@unittest.skipIf(not has_torch, "requires Torch")
@@ -228,12 +228,12 @@ class TestConv(mlx_tests.MLXTestCase):
mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
self.assertEqual(in_mx.shape, mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
@@ -309,7 +309,7 @@ class TestConv(mlx_tests.MLXTestCase):
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
for dtype in ("float32",):
@@ -419,12 +419,12 @@ class TestConv(mlx_tests.MLXTestCase):
mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
self.assertEqual(in_mx.shape, mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):