Fix reduce sum/prod overflow (#2477)

This commit is contained in:
Abe Leininger
2025-08-12 02:05:33 -05:00
committed by GitHub
parent 8ae4a76308
commit fce53b61d6
5 changed files with 75 additions and 8 deletions

View File

@@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
}
// sum and prod overflow
{
auto a = full({256, 2, 2}, 1u, uint8);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
a = full({65535, 2, 2}, 1u, uint16);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
}
}
TEST_CASE("test gpu reduce with axes") {
// reducing only some axes and irregular layouts
{
array a(1.0f);