mirror of https://github.com/ml-explore/mlx.git
Reduce specializations (#1607)
* start of reduce specializations
* fix all reduce
* fix many dims
* fix
* non-jit tests clear
* cleanup instantiations
* cpu merges
* change dim specializations
* optimize
* fix jit
* fix jit
* use higher precision for integer sum+prod
* fixes
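The functionally visible change is the last item, "use higher precision for integer sum+prod": sum and prod now pick a wider accumulator dtype for narrow integer and boolean inputs instead of reducing in the input dtype. The snippet below is a minimal standalone sketch of that promotion rule only; Kind, SimpleDtype, and reduce_out_type are hypothetical names invented for illustration and are not part of mlx, but the size thresholds and the bool -> int32 case mirror the diff that follows.

#include <cstddef>
#include <iostream>

// Hypothetical stand-ins for mlx's dtype kinds; only the promotion rule
// below mirrors the change in mlx/ops.cpp.
enum class Kind { SignedInt, UnsignedInt, Bool, Other };

struct SimpleDtype {
  Kind kind;
  std::size_t size;  // element size in bytes
};

// Promotion rule from the diff: integer inputs of 4 bytes or fewer
// accumulate in a 32-bit type, wider ones in a 64-bit type, and bool
// promotes to int32. Everything else keeps its dtype.
SimpleDtype reduce_out_type(SimpleDtype in) {
  switch (in.kind) {
    case Kind::SignedInt:
      return {Kind::SignedInt, in.size <= 4 ? std::size_t(4) : std::size_t(8)};
    case Kind::UnsignedInt:
      return {Kind::UnsignedInt, in.size <= 4 ? std::size_t(4) : std::size_t(8)};
    case Kind::Bool:
      return {Kind::SignedInt, 4};  // bool -> int32
    default:
      return in;  // floats etc. are unchanged
  }
}

int main() {
  // int8 input -> 4-byte signed accumulator (int32)
  std::cout << reduce_out_type({Kind::SignedInt, 1}).size << "\n";    // 4
  // uint64 input stays 8 bytes (uint64)
  std::cout << reduce_out_type({Kind::UnsignedInt, 8}).size << "\n";  // 8
  // bool input -> 4-byte signed accumulator (int32)
  std::cout << reduce_out_type({Kind::Bool, 1}).size << "\n";         // 4
  return 0;
}

Reducing in a wider type matters because summing even a few hundred int8 values can overflow 8 bits; widening to 32 or 64 bits makes the result dtype independent of how many elements are reduced.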
mlx/ops.cpp (19 changed lines)
@@ -1615,7 +1615,14 @@ array sum(
   }
   auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
       compute_reduce_shape(axes, a.shape());
-  auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
+  Dtype out_type = a.dtype();
+  if (issubdtype(a.dtype(), signedinteger)) {
+    out_type = a.dtype().size() <= 4 ? int32 : int64;
+  } else if (issubdtype(a.dtype(), unsignedinteger)) {
+    out_type = a.dtype().size() <= 4 ? uint32 : uint64;
+  } else if (a.dtype() == bool_) {
+    out_type = int32;
+  }
   auto out = (is_noop)
       ? astype(a, out_type, s)
       : array(
@@ -1760,11 +1767,19 @@ array prod(
   }
   auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
       compute_reduce_shape(axes, a.shape());
+  Dtype out_type = a.dtype();
+  if (issubdtype(a.dtype(), signedinteger)) {
+    out_type = a.dtype().size() <= 4 ? int32 : int64;
+  } else if (issubdtype(a.dtype(), unsignedinteger)) {
+    out_type = a.dtype().size() <= 4 ? uint32 : uint64;
+  } else if (a.dtype() == bool_) {
+    out_type = int32;
+  }
   auto out = (is_noop)
       ? a
       : array(
           std::move(out_shape),
-          a.dtype(),
+          out_type,
           std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
           {a});
   if (!keepdims) {
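The prod hunk applies the same widening rule as sum and additionally switches the dtype passed to the array constructor from a.dtype() to out_type, so the output array of the reduction is created with the promoted dtype rather than the input's.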