Refactor reductions and fix scatter atomics for large sizes (#1300)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-08-22 16:03:31 -07:00
committed by GitHub
parent f9e00efe31
commit 98b6ce3460
18 changed files with 1584 additions and 1235 deletions

View File

@@ -16,9 +16,11 @@ namespace mlx::core {
namespace {
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
compute_reduce_shape(
const std::vector<int>& axes,
const std::vector<int>& shape) {
bool is_noop = true;
std::set<int> axes_set;
auto ndim = shape.size();
for (auto ax : axes) {
@@ -35,15 +37,18 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
throw std::invalid_argument("Duplicate axes detected in reduction.");
}
std::vector<int> out_shape;
std::vector<int> squeezed_shape;
for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]);
squeezed_shape.push_back(shape[i]);
} else {
out_shape.push_back(1);
}
is_noop &= (out_shape.back() == shape[i]);
}
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
return {out_shape, sorted_axes};
return {out_shape, sorted_axes, squeezed_shape, is_noop};
}
Dtype at_least_float(const Dtype& d) {
@@ -1502,17 +1507,17 @@ array all(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
: array(
std::move(out_shape),
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1536,17 +1541,17 @@ array any(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
: array(
std::move(out_shape),
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1573,15 +1578,18 @@ array sum(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
auto out = array(
out_shape,
out_type,
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
auto out = (is_noop)
? astype(a, out_type, s)
: array(
std::move(out_shape),
out_type,
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1715,14 +1723,17 @@ array prod(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1749,17 +1760,17 @@ array max(
if (a.size() == 0) {
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
}
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1789,14 +1800,17 @@ array min(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1829,15 +1843,18 @@ array argmin(
throw std::invalid_argument(
"[argmin] Cannot argmin reduce zero size array.");
}
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
out_shape,
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
: array(
std::move(out_shape),
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1862,15 +1879,18 @@ array argmax(
throw std::invalid_argument(
"[argmax] Cannot argmax reduce zero size array.");
}
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
out_shape,
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
: array(
std::move(out_shape),
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}