add median op (#2705)

This commit is contained in:
Awni Hannun
2025-10-27 11:33:42 -07:00
committed by GitHub
parent c4767d110f
commit 539d8322d1
5 changed files with 164 additions and 0 deletions

View File

@@ -112,6 +112,7 @@ Operations
max
maximum
mean
median
meshgrid
min
minimum

View File

@@ -1932,6 +1932,87 @@ array mean(
return mean(a, std::vector<int>{axis}, keepdims, to_stream(s));
}
/** Median over every element of `a`: reduces across all of its axes. */
array median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
  // Build the axis list 0, 1, ..., ndim - 1 and defer to the axes overload.
  std::vector<int> all_axes(a.ndim());
  int next_axis = 0;
  for (int& ax : all_axes) {
    ax = next_axis++;
  }
  return median(a, all_axes, keepdims, to_stream(s));
}
/**
 * Computes the median of `a` over the given axes.
 *
 * The reduced axes are moved to the back, flattened into a single trailing
 * axis and sorted; the median is then the middle element of that axis, or the
 * average of the two middle elements when its length is even.
 *
 * @param a Input array.
 * @param axes Axes to reduce over. Negative values count from the last axis.
 *     An empty list reduces nothing and yields `a` cast to the result dtype.
 * @param keepdims If true, the reduced axes are re-inserted as singleton
 *     dimensions in the result.
 * @param s Stream or device passed to every op issued here.
 * @return The medians, with dtype `at_least_float(a.dtype())`.
 * @throws std::invalid_argument if an axis is out of bounds, if an axis is
 *     repeated, or if the flattened reduction axis has size zero.
 */
array median(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims /* = false */,
    StreamOrDevice s /* = {}*/) {
  const int ndim = static_cast<int>(a.ndim());

  // Normalize negative axes; the set both de-duplicates and orders them.
  std::set<int> set_axes;
  for (int axis : axes) {
    if (axis < -ndim || axis >= ndim) {
      std::ostringstream msg;
      msg << "[median] axis " << axis << " is out of bounds for array with "
          << ndim << " dimensions.";
      throw std::invalid_argument(msg.str());
    }
    set_axes.insert(axis < 0 ? axis + ndim : axis);
  }
  if (set_axes.size() != axes.size()) {
    throw std::invalid_argument("[median] Received duplicate axis.");
  }
  std::vector<int> sorted_axes(set_axes.begin(), set_axes.end());
  auto dtype = at_least_float(a.dtype());

  // With no reduction axes every element is its own median. Return early:
  // there is no trailing group of axes to flatten and sort in that case.
  if (sorted_axes.empty()) {
    return astype(a, dtype, s);
  }

  // Build the permutation: kept axes first (in their original order),
  // followed by the median axes.
  std::vector<int> transpose_axes;
  transpose_axes.reserve(ndim);
  auto next_reduced = sorted_axes.cbegin();
  for (int i = 0; i < ndim; ++i) {
    if (next_reduced != sorted_axes.cend() && *next_reduced == i) {
      ++next_reduced;
      continue;
    }
    transpose_axes.push_back(i);
  }
  const int flat_start = static_cast<int>(transpose_axes.size());
  transpose_axes.insert(
      transpose_axes.end(), sorted_axes.begin(), sorted_axes.end());

  // Move all the median axes to the back and flatten
  auto flat_a = flatten(transpose(a, transpose_axes, s), flat_start, ndim, s);
  int flat_size = flat_a.shape(-1);
  if (flat_size == 0) {
    throw std::invalid_argument(
        "[median] Cannot take median along empty axis.");
  }

  // Sort the last axis
  auto sorted_a = sort(flat_a, -1, s);

  // Take the midpoint: slice out the single element at index flat_size / 2
  auto mp = flat_size / 2;
  auto start = Shape(sorted_a.ndim(), 0);
  auto stop = sorted_a.shape();
  start.back() = mp;
  stop.back() = mp + 1;
  auto median_a = astype(slice(sorted_a, start, stop, s), dtype, s);

  // Even length: average the upper-middle element with the one just below
  // it. Both operands are cast to `dtype` before the addition.
  if (flat_size % 2 == 0) {
    start.back() = mp - 1;
    stop.back() = mp;
    median_a = multiply(
        add(median_a, astype(slice(sorted_a, start, stop, s), dtype, s), s),
        array(0.5, dtype),
        s);
  }

  // Drop the length-1 trailing axis left by the slice.
  median_a = squeeze(median_a, -1, s);
  if (keepdims) {
    // Re-insert the reduced axes as singleton dimensions.
    median_a = expand_dims(median_a, sorted_axes, s);
  }
  return median_a;
}
/** Median along a single axis; wraps `axis` in a list and forwards. */
array median(
    const array& a,
    int axis,
    bool keepdims /* = false */,
    StreamOrDevice s /* = {} */) {
  const std::vector<int> single_axis(1, axis);
  return median(a, single_axis, keepdims, to_stream(s));
}
array var(
const array& a,
bool keepdims,

View File

@@ -539,6 +539,26 @@ array mean(
bool keepdims = false,
StreamOrDevice s = {});
/**
 * Computes the median of the elements of an array, reducing over all axes.
 *
 * If `keepdims` is true the reduced axes are kept as singleton dimensions.
 */
array median(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same as above with `keepdims` fixed to false. */
inline array median(const array& a, StreamOrDevice s = {}) {
return median(a, false, to_stream(s));
}
/**
 * Computes the median of the elements of an array along the given axes.
 *
 * Negative axes count from the last dimension. For an even number of reduced
 * elements the result is the average of the two middle values. The result
 * dtype is `at_least_float(a.dtype())`.
 *
 * Throws std::invalid_argument if an axis is out of bounds or repeated, or if
 * the reduction covers zero elements.
 */
array median(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/**
 * Computes the median of the elements of an array along the given axis.
 *
 * Forwards to the multi-axis overload with a single-element axis list.
 */
array median(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the variance of the elements of an array. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array var(const array& a, StreamOrDevice s = {}) {

View File

@@ -2484,6 +2484,35 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array of means.
)pbdoc");
// Python binding for `median`. The first argument is positional-only and
// `stream` is keyword-only, as spelled out in the nb::sig signature below.
m.def(
"median",
[](const mx::array& a,
const IntOrVec& axis,
bool keepdims,
mx::StreamOrDevice s) {
// get_reduce_axes turns the Python-side `axis` (None, int or sequence)
// into the axis list for the C++ overload. Per the docstring below, None
// means "reduce over the entire array" -- confirm against
// get_reduce_axes, which is defined elsewhere.
return mx::median(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
},
nb::arg(),
"axis"_a = nb::none(),
"keepdims"_a = false,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def median(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the median(s) over the given axes.
Args:
a (array): Input array.
axis (int or list(int), optional): Optional axis or
axes to reduce over. If unspecified this defaults
to reducing over the entire array.
keepdims (bool, optional): Keep reduced axes as
singleton dimensions, defaults to `False`.
Returns:
array: The output array of medians.
)pbdoc");
m.def(
"var",
[](const mx::array& a,

View File

@@ -775,6 +775,39 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
def test_median(self):
    """Check mx.median error handling, scalar results and NumPy parity."""
    # Reducing along an axis of length zero is rejected.
    with self.assertRaises(ValueError):
        mx.median(mx.array([]), axis=0)

    odd = mx.array([0, 1, 2, 3, 4])
    # Out-of-bounds and repeated axes are both rejected.
    for bad_axes in [(0, 1), (0, 0)]:
        with self.assertRaises(ValueError):
            mx.median(odd, axis=bad_axes)

    # Odd length: the single middle element, reduced to a scalar.
    full = mx.median(odd)
    self.assertEqual(full.shape, ())
    self.assertEqual(full.item(), 2)

    # keepdims leaves the reduced axis as a singleton dimension.
    kept = mx.median(odd, keepdims=True)
    self.assertEqual(kept.shape, (1,))

    # Even length: the average of the two middle elements.
    even = mx.array([0, 1, 2, 3, 4, 5])
    self.assertEqual(mx.median(even).item(), 2.5)

    # Multi-axis reductions are compared against NumPy.
    data = mx.random.normal((5, 5, 5, 5))
    for axes in [(0, 2), (1, 3), (0, 1, 3)]:
        got = mx.median(data, axis=axes, keepdims=True)
        want = np.median(data, axis=axes, keepdims=True)
        self.assertTrue(np.allclose(got, want))
def test_var(self):
x = mx.array(
[