Reduce specializations (#1607)

* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
This commit is contained in:
Awni Hannun
2024-11-21 19:53:00 -08:00
committed by GitHub
parent dcca0d7477
commit 0c5eea226b
14 changed files with 733 additions and 406 deletions

View File

@@ -1615,7 +1615,14 @@ array sum(
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
Dtype out_type = a.dtype();
if (issubdtype(a.dtype(), signedinteger)) {
out_type = a.dtype().size() <= 4 ? int32 : int64;
} else if (issubdtype(a.dtype(), unsignedinteger)) {
out_type = a.dtype().size() <= 4 ? uint32 : uint64;
} else if (a.dtype() == bool_) {
out_type = int32;
}
auto out = (is_noop)
? astype(a, out_type, s)
: array(
@@ -1760,11 +1767,19 @@ array prod(
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
Dtype out_type = a.dtype();
if (issubdtype(a.dtype(), signedinteger)) {
out_type = a.dtype().size() <= 4 ? int32 : int64;
} else if (issubdtype(a.dtype(), unsignedinteger)) {
out_type = a.dtype().size() <= 4 ? uint32 : uint64;
} else if (a.dtype() == bool_) {
out_type = int32;
}
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
out_type,
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {