mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	| @@ -171,7 +171,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase): | ||||
|  | ||||
|                 # use torch to compute ct | ||||
|                 out_pt.retain_grad() | ||||
|                 (out_pt - torch.randn_like(out_pt)).abs().sum().backward() | ||||
|                 out_pt.sum().backward() | ||||
|  | ||||
|                 pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy() | ||||
|                 pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy() | ||||
| @@ -365,7 +365,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase): | ||||
|  | ||||
|                 # use torch to compute ct | ||||
|                 out_pt.retain_grad() | ||||
|                 (out_pt - torch.randn_like(out_pt)).abs().sum().backward() | ||||
|                 out_pt.sum().backward() | ||||
|  | ||||
|                 pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy() | ||||
|                 pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy() | ||||
| @@ -549,7 +549,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase): | ||||
|  | ||||
|                 # use torch to compute ct | ||||
|                 out_pt.retain_grad() | ||||
|                 (out_pt - torch.randn_like(out_pt)).abs().sum().backward() | ||||
|                 out_pt.sum().backward() | ||||
|  | ||||
|                 pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy() | ||||
|                 pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun