mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 01:46:37 +08:00
Fix reduce sum/prod overflow (#2477)
This commit is contained in:
parent
8ae4a76308
commit
fce53b61d6
@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
|
@ -134,6 +134,10 @@ instantiate_and_or(and, And)
|
||||
instantiate_and_or(or, Or)
|
||||
|
||||
#define instantiate_sum_prod(name, op) \
|
||||
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
|
||||
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||
|
@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
|
||||
const std::string& op_name) {
|
||||
if (op_name == "sum" || op_name == "prod") {
|
||||
if (issubdtype(in.dtype(), integer)) {
|
||||
switch (in.dtype().size()) {
|
||||
case 1:
|
||||
switch (in.dtype()) {
|
||||
case uint8:
|
||||
return {uint8, uint32};
|
||||
case uint16:
|
||||
return {uint16, uint32};
|
||||
case uint32:
|
||||
return {uint32, uint32};
|
||||
case uint64:
|
||||
return {uint64, uint64};
|
||||
case int8:
|
||||
return {int8, int32};
|
||||
case 2:
|
||||
case int16:
|
||||
return {int16, int32};
|
||||
case 4:
|
||||
case int32:
|
||||
return {int32, int32};
|
||||
case 8:
|
||||
case int64:
|
||||
return {int64, int64};
|
||||
default:
|
||||
throw std::runtime_error("Unsupported integer type");
|
||||
}
|
||||
}
|
||||
if (in.dtype() == bool_) {
|
||||
|
@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
|
||||
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
|
||||
}
|
||||
|
||||
// sum and prod overflow
|
||||
{
|
||||
auto a = full({256, 2, 2}, 1u, uint8);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||
|
||||
a = full({65535, 2, 2}, 1u, uint16);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test gpu reduce with axes") {
|
||||
// reducing only some axes and irregular layouts
|
||||
{
|
||||
array a(1.0f);
|
||||
|
@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
|
||||
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
|
||||
}
|
||||
|
||||
// Test unsigned sum
|
||||
{
|
||||
const int num_elems = 1000;
|
||||
|
||||
auto x = astype(full({num_elems}, 255), uint8);
|
||||
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
|
||||
|
||||
x = astype(full({num_elems}, 65535), uint16);
|
||||
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
|
||||
|
||||
x = full({3, 3, 3}, 10000, uint32);
|
||||
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
|
||||
|
||||
x = full({3, 3, 3}, 10000, uint64);
|
||||
CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
|
||||
}
|
||||
|
||||
// Test prod
|
||||
{
|
||||
auto x = array({});
|
||||
@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
|
||||
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
|
||||
}
|
||||
|
||||
// Test unsigned prod
|
||||
{
|
||||
auto x = array({255, 255}, {2}, uint8);
|
||||
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
|
||||
|
||||
x = array({65535, 2}, {2}, uint16);
|
||||
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
|
||||
|
||||
x = array({100000, 2}, {2}, uint32);
|
||||
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
|
||||
|
||||
x = array({100000, 2}, {2}, uint64);
|
||||
CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
|
||||
}
|
||||
|
||||
// Test all
|
||||
{
|
||||
auto x = array({});
|
||||
|
Loading…
Reference in New Issue
Block a user