Fix reduce sum/prod overflow (#2477)

This commit is contained in:
Abe Leininger 2025-08-12 02:05:33 -05:00 committed by GitHub
parent 8ae4a76308
commit fce53b61d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 75 additions and 8 deletions

View File

@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:

View File

@ -134,6 +134,10 @@ instantiate_and_or(and, And)
instantiate_and_or(or, Or)
#define instantiate_sum_prod(name, op) \
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \

View File

@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
const std::string& op_name) {
if (op_name == "sum" || op_name == "prod") {
if (issubdtype(in.dtype(), integer)) {
switch (in.dtype().size()) {
case 1:
switch (in.dtype()) {
case uint8:
return {uint8, uint32};
case uint16:
return {uint16, uint32};
case uint32:
return {uint32, uint32};
case uint64:
return {uint64, uint64};
case int8:
return {int8, int32};
case 2:
case int16:
return {int16, int32};
case 4:
case int32:
return {int32, int32};
case 8:
case int64:
return {int64, int64};
default:
throw std::runtime_error("Unsupported integer type");
}
}
if (in.dtype() == bool_) {

View File

@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
}
// sum and prod overflow
{
auto a = full({256, 2, 2}, 1u, uint8);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
a = full({65535, 2, 2}, 1u, uint16);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
}
}
TEST_CASE("test gpu reduce with axes") {
// reducing only some axes and irregular layouts
{
array a(1.0f);

View File

@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
}
// Test unsigned sum
{
const int num_elems = 1000;
auto x = astype(full({num_elems}, 255), uint8);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
x = astype(full({num_elems}, 65535), uint16);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
x = full({3, 3, 3}, 10000, uint32);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
x = full({3, 3, 3}, 10000, uint64);
CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
}
// Test prod
{
auto x = array({});
@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
}
// Test unsigned prod
{
auto x = array({255, 255}, {2}, uint8);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
x = array({65535, 2}, {2}, uint16);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
x = array({100000, 2}, {2}, uint32);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
x = array({100000, 2}, {2}, uint64);
CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
}
// Test all
{
auto x = array({});