Fix reduce sum/prod overflow (#2477)

This commit is contained in:
Abe Leininger 2025-08-12 02:05:33 -05:00 committed by GitHub
parent 8ae4a76308
commit fce53b61d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 75 additions and 8 deletions

View File

@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: case bool_:
case uint8: case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8: case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break; break;
case int16: case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break; break;
case int32: case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break; break;
case int64: case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break; break;
case float16: case float16:

View File

@ -134,6 +134,10 @@ instantiate_and_or(and, And)
instantiate_and_or(or, Or) instantiate_and_or(or, Or)
#define instantiate_sum_prod(name, op) \ #define instantiate_sum_prod(name, op) \
/* Unsigned inputs: the 4th argument is kept in step with the output dtype that remap_reduce_types picks on the host (uint8/uint16 -> uint32, uint32 -> uint32, uint64 -> uint64). NOTE(review): presumably this is the accumulator/output type - confirm against instantiate_reduce_functions. */ \
instantiate_reduce_functions(name, uint8, uint8_t, uint32_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \ instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \ instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \ instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \

View File

@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
const std::string& op_name) { const std::string& op_name) {
if (op_name == "sum" || op_name == "prod") { if (op_name == "sum" || op_name == "prod") {
if (issubdtype(in.dtype(), integer)) { if (issubdtype(in.dtype(), integer)) {
switch (in.dtype().size()) { switch (in.dtype()) {
case 1: case uint8:
return {uint8, uint32};
case uint16:
return {uint16, uint32};
case uint32:
return {uint32, uint32};
case uint64:
return {uint64, uint64};
case int8:
return {int8, int32}; return {int8, int32};
case 2: case int16:
return {int16, int32}; return {int16, int32};
case 4: case int32:
return {int32, int32}; return {int32, int32};
case 8: case int64:
return {int64, int64}; return {int64, int64};
default:
throw std::runtime_error("Unsupported integer type");
} }
} }
if (in.dtype() == bool_) { if (in.dtype() == bool_) {

View File

@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1); CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
} }
// sum and prod of uint8/uint16 inputs: the sum exceeds the input dtype's range, so the result is read back as uint32
{
auto a = full({256, 2, 2}, 1u, uint8);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
a = full({65535, 2, 2}, 1u, uint16);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
}
}
TEST_CASE("test gpu reduce with axes") {
// reducing only some axes and irregular layouts // reducing only some axes and irregular layouts
{ {
array a(1.0f); array a(1.0f);

View File

@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>()); CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
} }
// Test unsigned sum
{
const int num_elems = 1000;
auto x = astype(full({num_elems}, 255), uint8);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
x = astype(full({num_elems}, 65535), uint16);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
x = full({3, 3, 3}, 10000, uint32);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
x = full({3, 3, 3}, 10000, uint64);
CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
}
// Test prod // Test prod
{ {
auto x = array({}); auto x = array({});
@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>()); CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
} }
// Test unsigned prod
{
auto x = array({255, 255}, {2}, uint8);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
x = array({65535, 2}, {2}, uint16);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
x = array({100000, 2}, {2}, uint32);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
x = array({100000, 2}, {2}, uint64);
CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
}
// Test all // Test all
{ {
auto x = array({}); auto x = array({});