mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
Fix reduce sum/prod overflow (#2477)
This commit is contained in:
parent
8ae4a76308
commit
fce53b61d6
@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
switch (in.dtype()) {
|
switch (in.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
case uint8:
|
case uint8:
|
||||||
|
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
case uint16:
|
|
||||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
case uint32:
|
|
||||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int64:
|
case int64:
|
||||||
case uint64:
|
|
||||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
|
@ -134,6 +134,10 @@ instantiate_and_or(and, And)
|
|||||||
instantiate_and_or(or, Or)
|
instantiate_and_or(or, Or)
|
||||||
|
|
||||||
#define instantiate_sum_prod(name, op) \
|
#define instantiate_sum_prod(name, op) \
|
||||||
|
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
|
||||||
|
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
|
||||||
|
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
|
||||||
|
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
|
||||||
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
||||||
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
||||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||||
|
@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
|
|||||||
const std::string& op_name) {
|
const std::string& op_name) {
|
||||||
if (op_name == "sum" || op_name == "prod") {
|
if (op_name == "sum" || op_name == "prod") {
|
||||||
if (issubdtype(in.dtype(), integer)) {
|
if (issubdtype(in.dtype(), integer)) {
|
||||||
switch (in.dtype().size()) {
|
switch (in.dtype()) {
|
||||||
case 1:
|
case uint8:
|
||||||
|
return {uint8, uint32};
|
||||||
|
case uint16:
|
||||||
|
return {uint16, uint32};
|
||||||
|
case uint32:
|
||||||
|
return {uint32, uint32};
|
||||||
|
case uint64:
|
||||||
|
return {uint64, uint64};
|
||||||
|
case int8:
|
||||||
return {int8, int32};
|
return {int8, int32};
|
||||||
case 2:
|
case int16:
|
||||||
return {int16, int32};
|
return {int16, int32};
|
||||||
case 4:
|
case int32:
|
||||||
return {int32, int32};
|
return {int32, int32};
|
||||||
case 8:
|
case int64:
|
||||||
return {int64, int64};
|
return {int64, int64};
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported integer type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (in.dtype() == bool_) {
|
if (in.dtype() == bool_) {
|
||||||
|
@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
|
|||||||
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
|
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sum and prod overflow
|
||||||
|
{
|
||||||
|
auto a = full({256, 2, 2}, 1u, uint8);
|
||||||
|
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
|
||||||
|
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||||
|
|
||||||
|
a = full({65535, 2, 2}, 1u, uint16);
|
||||||
|
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
|
||||||
|
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test gpu reduce with axes") {
|
||||||
// reducing only some axes and irregular layouts
|
// reducing only some axes and irregular layouts
|
||||||
{
|
{
|
||||||
array a(1.0f);
|
array a(1.0f);
|
||||||
|
@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
|
|||||||
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
|
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test unsigned sum
|
||||||
|
{
|
||||||
|
const int num_elems = 1000;
|
||||||
|
|
||||||
|
auto x = astype(full({num_elems}, 255), uint8);
|
||||||
|
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
|
||||||
|
|
||||||
|
x = astype(full({num_elems}, 65535), uint16);
|
||||||
|
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
|
||||||
|
|
||||||
|
x = full({3, 3, 3}, 10000, uint32);
|
||||||
|
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
|
||||||
|
|
||||||
|
x = full({3, 3, 3}, 10000, uint64);
|
||||||
|
CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
|
||||||
|
}
|
||||||
|
|
||||||
// Test prod
|
// Test prod
|
||||||
{
|
{
|
||||||
auto x = array({});
|
auto x = array({});
|
||||||
@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
|
|||||||
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
|
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test unsigned prod
|
||||||
|
{
|
||||||
|
auto x = array({255, 255}, {2}, uint8);
|
||||||
|
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
|
||||||
|
|
||||||
|
x = array({65535, 2}, {2}, uint16);
|
||||||
|
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
|
||||||
|
|
||||||
|
x = array({100000, 2}, {2}, uint32);
|
||||||
|
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
|
||||||
|
|
||||||
|
x = array({100000, 2}, {2}, uint64);
|
||||||
|
CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
|
||||||
|
}
|
||||||
|
|
||||||
// Test all
|
// Test all
|
||||||
{
|
{
|
||||||
auto x = array({});
|
auto x = array({});
|
||||||
|
Loading…
Reference in New Issue
Block a user