mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add median op (#2705)
This commit is contained in:
81
mlx/ops.cpp
81
mlx/ops.cpp
@@ -1932,6 +1932,87 @@ array mean(
|
||||
return mean(a, std::vector<int>{axis}, keepdims, to_stream(s));
|
||||
}
|
||||
|
||||
array median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return median(a, axes, keepdims, to_stream(s));
|
||||
}
|
||||
|
||||
array median(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
int ndim = a.ndim();
|
||||
std::set<int> set_axes;
|
||||
for (int axis : axes) {
|
||||
if (axis < -ndim || axis >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[median] axis " << axis << " is out of bounds for array with "
|
||||
<< ndim << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
set_axes.insert(axis < 0 ? axis + ndim : axis);
|
||||
}
|
||||
if (set_axes.size() != axes.size()) {
|
||||
throw std::invalid_argument("[median] Received duplicate axis.");
|
||||
}
|
||||
std::vector<int> sorted_axes(set_axes.begin(), set_axes.end());
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
std::vector<int> transpose_axes;
|
||||
for (int i = 0, j = 0; i < a.ndim(); ++i) {
|
||||
if (j < sorted_axes.size() && i == sorted_axes[j]) {
|
||||
j++;
|
||||
continue;
|
||||
}
|
||||
transpose_axes.push_back(i);
|
||||
}
|
||||
int flat_start = transpose_axes.size();
|
||||
transpose_axes.insert(
|
||||
transpose_axes.end(), sorted_axes.begin(), sorted_axes.end());
|
||||
|
||||
// Move all the median axes to the back and flatten
|
||||
auto flat_a =
|
||||
flatten(transpose(a, transpose_axes, s), flat_start, a.ndim(), s);
|
||||
int flat_size = flat_a.shape(-1);
|
||||
if (flat_size == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[median] Cannot take median along empty axis.");
|
||||
}
|
||||
|
||||
// Sort the last axis
|
||||
auto sorted_a = sort(flat_a, -1, s);
|
||||
|
||||
// Take the midpoint
|
||||
auto mp = flat_size / 2;
|
||||
auto start = Shape(sorted_a.ndim(), 0);
|
||||
auto stop = sorted_a.shape();
|
||||
start.back() = mp;
|
||||
stop.back() = mp + 1;
|
||||
auto median_a = astype(slice(sorted_a, start, stop, s), dtype, s);
|
||||
if (flat_size % 2 == 0) {
|
||||
start.back() = mp - 1;
|
||||
stop.back() = mp;
|
||||
median_a = multiply(
|
||||
add(median_a, astype(slice(sorted_a, start, stop, s), dtype, s), s),
|
||||
array(0.5, dtype),
|
||||
s);
|
||||
}
|
||||
median_a = squeeze(median_a, -1, s);
|
||||
if (keepdims) {
|
||||
median_a = expand_dims(median_a, sorted_axes, s);
|
||||
}
|
||||
return median_a;
|
||||
}
|
||||
|
||||
array median(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return median(a, std::vector<int>{axis}, keepdims, to_stream(s));
|
||||
}
|
||||
|
||||
array var(
|
||||
const array& a,
|
||||
bool keepdims,
|
||||
|
||||
20
mlx/ops.h
20
mlx/ops.h
@@ -539,6 +539,26 @@ array mean(
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the median of the elements of an array. */
|
||||
array median(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||
inline array median(const array& a, StreamOrDevice s = {}) {
|
||||
return median(a, false, to_stream(s));
|
||||
}
|
||||
|
||||
/** Computes the median of the elements of an array along the given axes */
|
||||
array median(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the median of the elements of an array along the given axis */
|
||||
array median(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the variance of the elements of an array. */
|
||||
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||
inline array var(const array& a, StreamOrDevice s = {}) {
|
||||
|
||||
Reference in New Issue
Block a user