mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
Fix reduce sum/prod overflow (#2477)
This commit is contained in:
@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
|
@@ -134,6 +134,10 @@ instantiate_and_or(and, And)
|
||||
instantiate_and_or(or, Or)
|
||||
|
||||
#define instantiate_sum_prod(name, op) \
|
||||
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
|
||||
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||
|
@@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
|
||||
const std::string& op_name) {
|
||||
if (op_name == "sum" || op_name == "prod") {
|
||||
if (issubdtype(in.dtype(), integer)) {
|
||||
switch (in.dtype().size()) {
|
||||
case 1:
|
||||
switch (in.dtype()) {
|
||||
case uint8:
|
||||
return {uint8, uint32};
|
||||
case uint16:
|
||||
return {uint16, uint32};
|
||||
case uint32:
|
||||
return {uint32, uint32};
|
||||
case uint64:
|
||||
return {uint64, uint64};
|
||||
case int8:
|
||||
return {int8, int32};
|
||||
case 2:
|
||||
case int16:
|
||||
return {int16, int32};
|
||||
case 4:
|
||||
case int32:
|
||||
return {int32, int32};
|
||||
case 8:
|
||||
case int64:
|
||||
return {int64, int64};
|
||||
default:
|
||||
throw std::runtime_error("Unsupported integer type");
|
||||
}
|
||||
}
|
||||
if (in.dtype() == bool_) {
|
||||
|
Reference in New Issue
Block a user