mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 07:18:15 +08:00
Fix reduce sum/prod overflow (#2477)
This commit is contained in:
@@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
|
||||
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
|
||||
}
|
||||
|
||||
// sum and prod overflow
|
||||
{
|
||||
auto a = full({256, 2, 2}, 1u, uint8);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||
|
||||
a = full({65535, 2, 2}, 1u, uint16);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test gpu reduce with axes") {
|
||||
// reducing only some axes and irregular layouts
|
||||
{
|
||||
array a(1.0f);
|
||||
|
||||
Reference in New Issue
Block a user