mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Fix logsumexp edge case (#740)
* fix logsumexp * fix inf constant * also fix power grad * fix ternary dispatch
This commit is contained in:
		| @@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase): | ||||
|         _, vjps = mx.vjp(func, (arr,), (cotan,)) | ||||
|         self.assertEqual(vjps[0].item(), 8.0) | ||||
|  | ||||
|     def test_power_grad(self): | ||||
|         def fun(x, y): | ||||
|             res = x - y | ||||
|             return res**x | ||||
|  | ||||
|         grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0)) | ||||
|         self.assertEqual(grad.item(), 1.0) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -539,6 +539,15 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         z = fun(mx.array(1), "two") | ||||
|         self.assertEqual(z.item(), 3) | ||||
|  | ||||
|     def test_compile_inf(self): | ||||
|  | ||||
|         @mx.compile | ||||
|         def fun(x): | ||||
|             return mx.isinf(x + 2) | ||||
|  | ||||
|         out = fun(mx.array([0.0])) | ||||
|         self.assertEqual(out.item(), False) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -66,13 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|     def test_save_and_load_safetensors(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
|         test_file = os.path.join(self.test_dir, "test.safetensors") | ||||
|         with self.assertRaises(Exception): | ||||
|             mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0}) | ||||
|             mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0}) | ||||
|  | ||||
|         mx.save_safetensors( | ||||
|             "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"} | ||||
|             test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"} | ||||
|         ) | ||||
|         res = mx.load("test.safetensors", return_metadata=True) | ||||
|         res = mx.load(test_file, return_metadata=True) | ||||
|         self.assertEqual(len(res), 2) | ||||
|         self.assertEqual(res[1], {"testing": "test", "format": "mlx"}) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun