add median op (#2705)

This commit is contained in:
Awni Hannun
2025-10-27 11:33:42 -07:00
committed by GitHub
parent c4767d110f
commit 539d8322d1
5 changed files with 164 additions and 0 deletions

View File

@@ -775,6 +775,39 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
def test_median(self):
x = mx.array([])
with self.assertRaises(ValueError):
mx.median(x, axis=0)
x = mx.array([0, 1, 2, 3, 4])
with self.assertRaises(ValueError):
mx.median(x, axis=(0, 1))
with self.assertRaises(ValueError):
mx.median(x, axis=(0, 0))
out = mx.median(x)
self.assertEqual(out.shape, ())
self.assertEqual(out.item(), 2)
out = mx.median(x, keepdims=True)
self.assertEqual(out.shape, (1,))
x = mx.array([0, 1, 2, 3, 4, 5])
out = mx.median(x)
self.assertEqual(out.item(), 2.5)
x = mx.random.normal((5, 5, 5, 5))
out = mx.median(x, axis=(0, 2), keepdims=True)
out_np = np.median(x, axis=(0, 2), keepdims=True)
self.assertTrue(np.allclose(out, out_np))
out = mx.median(x, axis=(1, 3), keepdims=True)
out_np = np.median(x, axis=(1, 3), keepdims=True)
self.assertTrue(np.allclose(out, out_np))
out = mx.median(x, axis=(0, 1, 3), keepdims=True)
out_np = np.median(x, axis=(0, 1, 3), keepdims=True)
self.assertTrue(np.allclose(out, out_np))
def test_var(self):
x = mx.array(
[