mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
158
mlx/ops.cpp
158
mlx/ops.cpp
@@ -16,9 +16,11 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
|
||||
compute_reduce_shape(
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& shape) {
|
||||
bool is_noop = true;
|
||||
std::set<int> axes_set;
|
||||
auto ndim = shape.size();
|
||||
for (auto ax : axes) {
|
||||
@@ -35,15 +37,18 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
throw std::invalid_argument("Duplicate axes detected in reduction.");
|
||||
}
|
||||
std::vector<int> out_shape;
|
||||
std::vector<int> squeezed_shape;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
if (axes_set.count(i) == 0) {
|
||||
out_shape.push_back(shape[i]);
|
||||
squeezed_shape.push_back(shape[i]);
|
||||
} else {
|
||||
out_shape.push_back(1);
|
||||
}
|
||||
is_noop &= (out_shape.back() == shape[i]);
|
||||
}
|
||||
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
|
||||
return {out_shape, sorted_axes};
|
||||
return {out_shape, sorted_axes, squeezed_shape, is_noop};
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
@@ -1502,17 +1507,17 @@ array all(
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.empty()) {
|
||||
return astype(a, bool_, s);
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? astype(a, bool_, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1536,17 +1541,17 @@ array any(
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.empty()) {
|
||||
return astype(a, bool_, s);
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? astype(a, bool_, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1573,15 +1578,18 @@ array sum(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
|
||||
auto out = array(
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||
{a});
|
||||
auto out = (is_noop)
|
||||
? astype(a, out_type, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1715,14 +1723,17 @@ array prod(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1749,17 +1760,17 @@ array max(
|
||||
if (a.size() == 0) {
|
||||
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
|
||||
}
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1789,14 +1800,17 @@ array min(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1829,15 +1843,18 @@ array argmin(
|
||||
throw std::invalid_argument(
|
||||
"[argmin] Cannot argmin reduce zero size array.");
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape({axis}, a.shape());
|
||||
auto out = (is_noop)
|
||||
? zeros(out_shape, uint32, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1862,15 +1879,18 @@ array argmax(
|
||||
throw std::invalid_argument(
|
||||
"[argmax] Cannot argmax reduce zero size array.");
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape({axis}, a.shape());
|
||||
auto out = (is_noop)
|
||||
? zeros(out_shape, uint32, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user