mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 20:20:11 +08:00
add median op (#2705)
This commit is contained in:
@@ -775,6 +775,39 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
|
||||
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
|
||||
|
||||
def test_median(self):
|
||||
x = mx.array([])
|
||||
with self.assertRaises(ValueError):
|
||||
mx.median(x, axis=0)
|
||||
x = mx.array([0, 1, 2, 3, 4])
|
||||
with self.assertRaises(ValueError):
|
||||
mx.median(x, axis=(0, 1))
|
||||
with self.assertRaises(ValueError):
|
||||
mx.median(x, axis=(0, 0))
|
||||
|
||||
out = mx.median(x)
|
||||
self.assertEqual(out.shape, ())
|
||||
self.assertEqual(out.item(), 2)
|
||||
out = mx.median(x, keepdims=True)
|
||||
self.assertEqual(out.shape, (1,))
|
||||
|
||||
x = mx.array([0, 1, 2, 3, 4, 5])
|
||||
out = mx.median(x)
|
||||
self.assertEqual(out.item(), 2.5)
|
||||
|
||||
x = mx.random.normal((5, 5, 5, 5))
|
||||
out = mx.median(x, axis=(0, 2), keepdims=True)
|
||||
out_np = np.median(x, axis=(0, 2), keepdims=True)
|
||||
self.assertTrue(np.allclose(out, out_np))
|
||||
|
||||
out = mx.median(x, axis=(1, 3), keepdims=True)
|
||||
out_np = np.median(x, axis=(1, 3), keepdims=True)
|
||||
self.assertTrue(np.allclose(out, out_np))
|
||||
|
||||
out = mx.median(x, axis=(0, 1, 3), keepdims=True)
|
||||
out_np = np.median(x, axis=(0, 1, 3), keepdims=True)
|
||||
self.assertTrue(np.allclose(out, out_np))
|
||||
|
||||
def test_var(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user