diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 55fc1f534..6aabfc695 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -112,6 +112,7 @@ Operations max maximum mean + median meshgrid min minimum diff --git a/mlx/ops.cpp b/mlx/ops.cpp index dc985ce95..30e934f82 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1932,6 +1932,87 @@ array mean( return mean(a, std::vector{axis}, keepdims, to_stream(s)); } +array median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return median(a, axes, keepdims, to_stream(s)); +} + +array median( + const array& a, + const std::vector& axes, + bool keepdims /* = false */, + StreamOrDevice s /* = {}*/) { + int ndim = a.ndim(); + std::set set_axes; + for (int axis : axes) { + if (axis < -ndim || axis >= ndim) { + std::ostringstream msg; + msg << "[median] axis " << axis << " is out of bounds for array with " + << ndim << " dimensions."; + throw std::invalid_argument(msg.str()); + } + set_axes.insert(axis < 0 ? axis + ndim : axis); + } + if (set_axes.size() != axes.size()) { + throw std::invalid_argument("[median] Received duplicate axis."); + } + std::vector sorted_axes(set_axes.begin(), set_axes.end()); + auto dtype = at_least_float(a.dtype()); + std::vector transpose_axes; + for (int i = 0, j = 0; i < a.ndim(); ++i) { + if (j < sorted_axes.size() && i == sorted_axes[j]) { + j++; + continue; + } + transpose_axes.push_back(i); + } + int flat_start = transpose_axes.size(); + transpose_axes.insert( + transpose_axes.end(), sorted_axes.begin(), sorted_axes.end()); + + // Move all the median axes to the back and flatten + auto flat_a = + flatten(transpose(a, transpose_axes, s), flat_start, a.ndim(), s); + int flat_size = flat_a.shape(-1); + if (flat_size == 0) { + throw std::invalid_argument( + "[median] Cannot take median along empty axis."); + } + + // Sort the last axis + auto sorted_a = sort(flat_a, -1, s); + + // Take the midpoint + auto mp = flat_size / 2; + auto start = Shape(sorted_a.ndim(), 0); + auto stop = sorted_a.shape(); + start.back() = mp; + stop.back() = mp + 1; + auto median_a = astype(slice(sorted_a, start, stop, s), dtype, s); + if (flat_size % 2 == 0) { + start.back() = mp - 1; + stop.back() = mp; + median_a = multiply( + add(median_a, astype(slice(sorted_a, start, stop, s), dtype, s), s), + array(0.5, dtype), + s); + } + median_a = squeeze(median_a, -1, s); + if (keepdims) { + median_a = expand_dims(median_a, sorted_axes, s); + } + return median_a; +} + +array median( + const array& a, + int axis, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + return median(a, std::vector{axis}, keepdims, to_stream(s)); +} + array var( const array& a, bool keepdims, diff --git a/mlx/ops.h b/mlx/ops.h index cb3505df4..bfe6eff16 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -539,6 +539,26 @@ array mean( bool keepdims = false, StreamOrDevice s = {}); +/** Computes the median of the elements of an array. */ +array median(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array median(const array& a, StreamOrDevice s = {}) { + return median(a, false, to_stream(s)); +} + +/** Computes the median of the elements of an array along the given axes */ +array median( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the median of the elements of an array along the given axis */ +array median( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + /** Computes the variance of the elements of an array. */ array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {}); inline array var(const array& a, StreamOrDevice s = {}) { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2bf1a7ab1..2e364db76 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2484,6 +2484,35 @@ void init_ops(nb::module_& m) { Returns: array: The output array of means. )pbdoc"); + m.def( + "median", + [](const mx::array& a, + const IntOrVec& axis, + bool keepdims, + mx::StreamOrDevice s) { + return mx::median(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + nb::arg(), + "axis"_a = nb::none(), + "keepdims"_a = false, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def median(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the median(s) over the given axes. + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array of medians. + )pbdoc"); m.def( "var", [](const mx::array& a, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 30ba12417..8a353d743 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -775,6 +775,39 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3]) self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5]) + def test_median(self): + x = mx.array([]) + with self.assertRaises(ValueError): + mx.median(x, axis=0) + x = mx.array([0, 1, 2, 3, 4]) + with self.assertRaises(ValueError): + mx.median(x, axis=(0, 1)) + with self.assertRaises(ValueError): + mx.median(x, axis=(0, 0)) + + out = mx.median(x) + self.assertEqual(out.shape, ()) + self.assertEqual(out.item(), 2) + out = mx.median(x, keepdims=True) + self.assertEqual(out.shape, (1,)) + + x = mx.array([0, 1, 2, 3, 4, 5]) + out = mx.median(x) + self.assertEqual(out.item(), 2.5) + + x = mx.random.normal((5, 5, 5, 5)) + out = mx.median(x, axis=(0, 2), keepdims=True) + out_np = np.median(x, axis=(0, 2), keepdims=True) + self.assertTrue(np.allclose(out, out_np)) + + out = mx.median(x, axis=(1, 3), keepdims=True) + out_np = np.median(x, axis=(1, 3), keepdims=True) + self.assertTrue(np.allclose(out, out_np)) + + out = mx.median(x, axis=(0, 1, 3), keepdims=True) + out_np = np.median(x, axis=(0, 1, 3), keepdims=True) + self.assertTrue(np.allclose(out, out_np)) + def test_var(self): x = mx.array( [