mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	std and expm1 (#973)
* std and expm1 * actually add expm1 * fix linux * fix vjp * relax tol for linux test * Add it to the compilable primitives --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		@@ -725,6 +725,11 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
        out = mx.var(x, ddof=3)
 | 
			
		||||
        self.assertEqual(out.item(), float("inf"))
 | 
			
		||||
 | 
			
		||||
    def test_std(self):
 | 
			
		||||
        x = mx.random.uniform(shape=(5, 5))
 | 
			
		||||
        x_np = np.array(x)
 | 
			
		||||
        self.assertAlmostEqual(mx.std(x).item(), x_np.std().item(), places=6)
 | 
			
		||||
 | 
			
		||||
    def test_abs(self):
 | 
			
		||||
        a = mx.array([-1.0, 1.0, -2.0, 3.0])
 | 
			
		||||
        result = mx.abs(a)
 | 
			
		||||
@@ -839,6 +844,13 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(np.allclose(result, expected))
 | 
			
		||||
 | 
			
		||||
    def test_expm1(self):
 | 
			
		||||
        a = mx.array([0, 0.5, -0.5, 5])
 | 
			
		||||
        result = mx.expm1(a)
 | 
			
		||||
        expected = np.expm1(a, dtype=np.float32)
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))
 | 
			
		||||
 | 
			
		||||
    def test_erf(self):
 | 
			
		||||
        inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0]
 | 
			
		||||
        x = mx.array(inputs)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user