Fix cpu segfault (#1488)

* fix cpu segfault

* nit in tests
This commit is contained in:
Awni Hannun
2024-10-14 16:17:03 -07:00
committed by GitHub
parent 020f048cd0
commit 0ab8e099e8
6 changed files with 52 additions and 51 deletions

View File

@@ -171,7 +171,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
# use torch to compute ct
out_pt.retain_grad()
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
out_pt.sum().backward()
pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()
@@ -365,7 +365,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
# use torch to compute ct
out_pt.retain_grad()
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
out_pt.sum().backward()
pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()
@@ -549,7 +549,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
# use torch to compute ct
out_pt.retain_grad()
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
out_pt.sum().backward()
pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()