mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 09:18:11 +08:00
add median op (#2705)
This commit is contained in:
@@ -112,6 +112,7 @@ Operations
|
||||
max
|
||||
maximum
|
||||
mean
|
||||
median
|
||||
meshgrid
|
||||
min
|
||||
minimum
|
||||
|
||||
81
mlx/ops.cpp
81
mlx/ops.cpp
@@ -1932,6 +1932,87 @@ array mean(
|
||||
return mean(a, std::vector<int>{axis}, keepdims, to_stream(s));
|
||||
}
|
||||
|
||||
array median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
||||
std::vector<int> axes(a.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return median(a, axes, keepdims, to_stream(s));
|
||||
}
|
||||
|
||||
array median(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
int ndim = a.ndim();
|
||||
std::set<int> set_axes;
|
||||
for (int axis : axes) {
|
||||
if (axis < -ndim || axis >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[median] axis " << axis << " is out of bounds for array with "
|
||||
<< ndim << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
set_axes.insert(axis < 0 ? axis + ndim : axis);
|
||||
}
|
||||
if (set_axes.size() != axes.size()) {
|
||||
throw std::invalid_argument("[median] Received duplicate axis.");
|
||||
}
|
||||
std::vector<int> sorted_axes(set_axes.begin(), set_axes.end());
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
std::vector<int> transpose_axes;
|
||||
for (int i = 0, j = 0; i < a.ndim(); ++i) {
|
||||
if (j < sorted_axes.size() && i == sorted_axes[j]) {
|
||||
j++;
|
||||
continue;
|
||||
}
|
||||
transpose_axes.push_back(i);
|
||||
}
|
||||
int flat_start = transpose_axes.size();
|
||||
transpose_axes.insert(
|
||||
transpose_axes.end(), sorted_axes.begin(), sorted_axes.end());
|
||||
|
||||
// Move all the median axes to the back and flatten
|
||||
auto flat_a =
|
||||
flatten(transpose(a, transpose_axes, s), flat_start, a.ndim(), s);
|
||||
int flat_size = flat_a.shape(-1);
|
||||
if (flat_size == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[median] Cannot take median along empty axis.");
|
||||
}
|
||||
|
||||
// Sort the last axis
|
||||
auto sorted_a = sort(flat_a, -1, s);
|
||||
|
||||
// Take the midpoint
|
||||
auto mp = flat_size / 2;
|
||||
auto start = Shape(sorted_a.ndim(), 0);
|
||||
auto stop = sorted_a.shape();
|
||||
start.back() = mp;
|
||||
stop.back() = mp + 1;
|
||||
auto median_a = astype(slice(sorted_a, start, stop, s), dtype, s);
|
||||
if (flat_size % 2 == 0) {
|
||||
start.back() = mp - 1;
|
||||
stop.back() = mp;
|
||||
median_a = multiply(
|
||||
add(median_a, astype(slice(sorted_a, start, stop, s), dtype, s), s),
|
||||
array(0.5, dtype),
|
||||
s);
|
||||
}
|
||||
median_a = squeeze(median_a, -1, s);
|
||||
if (keepdims) {
|
||||
median_a = expand_dims(median_a, sorted_axes, s);
|
||||
}
|
||||
return median_a;
|
||||
}
|
||||
|
||||
array median(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return median(a, std::vector<int>{axis}, keepdims, to_stream(s));
|
||||
}
|
||||
|
||||
array var(
|
||||
const array& a,
|
||||
bool keepdims,
|
||||
|
||||
20
mlx/ops.h
20
mlx/ops.h
@@ -539,6 +539,26 @@ array mean(
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the median of the elements of an array. */
|
||||
array median(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||
inline array median(const array& a, StreamOrDevice s = {}) {
|
||||
return median(a, false, to_stream(s));
|
||||
}
|
||||
|
||||
/** Computes the median of the elements of an array along the given axes */
|
||||
array median(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the median of the elements of an array along the given axis */
|
||||
array median(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes the variance of the elements of an array. */
|
||||
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
||||
inline array var(const array& a, StreamOrDevice s = {}) {
|
||||
|
||||
@@ -2484,6 +2484,35 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The output array of means.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"median",
|
||||
[](const mx::array& a,
|
||||
const IntOrVec& axis,
|
||||
bool keepdims,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::median(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
||||
},
|
||||
nb::arg(),
|
||||
"axis"_a = nb::none(),
|
||||
"keepdims"_a = false,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def median(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the median(s) over the given axes.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
axis (int or list(int), optional): Optional axis or
|
||||
axes to reduce over. If unspecified this defaults
|
||||
to reducing over the entire array.
|
||||
keepdims (bool, optional): Keep reduced axes as
|
||||
singleton dimensions, defaults to `False`.
|
||||
|
||||
Returns:
|
||||
array: The output array of medians.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"var",
|
||||
[](const mx::array& a,
|
||||
|
||||
@@ -775,6 +775,39 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
|
||||
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
|
||||
|
||||
def test_median(self):
|
||||
x = mx.array([])
|
||||
with self.assertRaises(ValueError):
|
||||
mx.median(x, axis=0)
|
||||
x = mx.array([0, 1, 2, 3, 4])
|
||||
with self.assertRaises(ValueError):
|
||||
mx.median(x, axis=(0, 1))
|
||||
with self.assertRaises(ValueError):
|
||||
mx.median(x, axis=(0, 0))
|
||||
|
||||
out = mx.median(x)
|
||||
self.assertEqual(out.shape, ())
|
||||
self.assertEqual(out.item(), 2)
|
||||
out = mx.median(x, keepdims=True)
|
||||
self.assertEqual(out.shape, (1,))
|
||||
|
||||
x = mx.array([0, 1, 2, 3, 4, 5])
|
||||
out = mx.median(x)
|
||||
self.assertEqual(out.item(), 2.5)
|
||||
|
||||
x = mx.random.normal((5, 5, 5, 5))
|
||||
out = mx.median(x, axis=(0, 2), keepdims=True)
|
||||
out_np = np.median(x, axis=(0, 2), keepdims=True)
|
||||
self.assertTrue(np.allclose(out, out_np))
|
||||
|
||||
out = mx.median(x, axis=(1, 3), keepdims=True)
|
||||
out_np = np.median(x, axis=(1, 3), keepdims=True)
|
||||
self.assertTrue(np.allclose(out, out_np))
|
||||
|
||||
out = mx.median(x, axis=(0, 1, 3), keepdims=True)
|
||||
out_np = np.median(x, axis=(0, 1, 3), keepdims=True)
|
||||
self.assertTrue(np.allclose(out, out_np))
|
||||
|
||||
def test_var(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user