add median op (#2705)

This commit is contained in:
Awni Hannun
2025-10-27 11:33:42 -07:00
committed by GitHub
parent c4767d110f
commit 539d8322d1
5 changed files with 164 additions and 0 deletions

View File

@@ -1932,6 +1932,87 @@ array mean(
return mean(a, std::vector<int>{axis}, keepdims, to_stream(s));
}
array median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return median(a, axes, keepdims, to_stream(s));
}
array median(
const array& a,
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
int ndim = a.ndim();
std::set<int> set_axes;
for (int axis : axes) {
if (axis < -ndim || axis >= ndim) {
std::ostringstream msg;
msg << "[median] axis " << axis << " is out of bounds for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
set_axes.insert(axis < 0 ? axis + ndim : axis);
}
if (set_axes.size() != axes.size()) {
throw std::invalid_argument("[median] Received duplicate axis.");
}
std::vector<int> sorted_axes(set_axes.begin(), set_axes.end());
auto dtype = at_least_float(a.dtype());
std::vector<int> transpose_axes;
for (int i = 0, j = 0; i < a.ndim(); ++i) {
if (j < sorted_axes.size() && i == sorted_axes[j]) {
j++;
continue;
}
transpose_axes.push_back(i);
}
int flat_start = transpose_axes.size();
transpose_axes.insert(
transpose_axes.end(), sorted_axes.begin(), sorted_axes.end());
// Move all the median axes to the back and flatten
auto flat_a =
flatten(transpose(a, transpose_axes, s), flat_start, a.ndim(), s);
int flat_size = flat_a.shape(-1);
if (flat_size == 0) {
throw std::invalid_argument(
"[median] Cannot take median along empty axis.");
}
// Sort the last axis
auto sorted_a = sort(flat_a, -1, s);
// Take the midpoint
auto mp = flat_size / 2;
auto start = Shape(sorted_a.ndim(), 0);
auto stop = sorted_a.shape();
start.back() = mp;
stop.back() = mp + 1;
auto median_a = astype(slice(sorted_a, start, stop, s), dtype, s);
if (flat_size % 2 == 0) {
start.back() = mp - 1;
stop.back() = mp;
median_a = multiply(
add(median_a, astype(slice(sorted_a, start, stop, s), dtype, s), s),
array(0.5, dtype),
s);
}
median_a = squeeze(median_a, -1, s);
if (keepdims) {
median_a = expand_dims(median_a, sorted_axes, s);
}
return median_a;
}
array median(
const array& a,
int axis,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
return median(a, std::vector<int>{axis}, keepdims, to_stream(s));
}
array var(
const array& a,
bool keepdims,

View File

@@ -539,6 +539,26 @@ array mean(
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the median of the elements of an array. */
array median(const array& a, bool keepdims, StreamOrDevice s = {});
inline array median(const array& a, StreamOrDevice s = {}) {
return median(a, false, to_stream(s));
}
/** Computes the median of the elements of an array along the given axes */
array median(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the median of the elements of an array along the given axis */
array median(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the variance of the elements of an array. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array var(const array& a, StreamOrDevice s = {}) {