mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 22:44:38 +08:00
Make shape a tuple (#591)
* shape tuple * also remove simplify from docs * rebase
This commit is contained in:
@@ -44,7 +44,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
c_np = np.convolve(a_np, v_np, mode=mode)
|
||||
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
|
||||
|
||||
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
|
||||
self.assertEqual(c_mx.shape, c_np.shape)
|
||||
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
@@ -102,7 +102,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
@@ -141,7 +141,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
out_pt = torch.conv1d(in_pt, wt_pt)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
@@ -228,12 +228,12 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
|
||||
self.assertEqual(in_mx.shape, mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
|
||||
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
@@ -309,7 +309,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
@@ -419,12 +419,12 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
|
||||
self.assertEqual(in_mx.shape, mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
|
||||
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
|
Reference in New Issue
Block a user