From fce53b61d6a93f3a86e74b0a8a3bc86547228c11 Mon Sep 17 00:00:00 2001 From: Abe Leininger <95333017+abeleinin@users.noreply.github.com> Date: Tue, 12 Aug 2025 02:05:33 -0500 Subject: [PATCH] Fix reduce sum/prod overflow (#2477) --- mlx/backend/cpu/reduce.cpp | 14 ++++++++--- mlx/backend/metal/kernels/reduce.metal | 4 ++++ mlx/backend/metal/reduce.cpp | 20 ++++++++++++---- tests/gpu_tests.cpp | 13 +++++++++++ tests/ops_tests.cpp | 32 ++++++++++++++++++++++++++ 5 files changed, 75 insertions(+), 8 deletions(-) diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index 8febbd050..41764f4c8 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { switch (in.dtype()) { case bool_: case uint8: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case uint16: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case uint32: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case uint64: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; case int8: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int16: - case uint16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int32: - case uint32: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int64: - case uint64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case float16: diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index 428f65012..de5dfbad7 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -134,6 +134,10 @@ instantiate_and_or(and, And) instantiate_and_or(or, Or) #define instantiate_sum_prod(name, op) \ + instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \ + instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \ + instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \ + instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \ instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \ instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \ instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \ diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 3ae766ba9..504943d82 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -247,15 +247,25 @@ std::pair remap_reduce_types( const std::string& op_name) { if (op_name == "sum" || op_name == "prod") { if (issubdtype(in.dtype(), integer)) { - switch (in.dtype().size()) { - case 1: + switch (in.dtype()) { + case uint8: + return {uint8, uint32}; + case uint16: + return {uint16, uint32}; + case uint32: + return {uint32, uint32}; + case uint64: + return {uint64, uint64}; + case int8: return {int8, int32}; - case 2: + case int16: return {int16, int32}; - case 4: + case int32: return {int32, int32}; - case 8: + case int64: return {int64, int64}; + default: + throw std::runtime_error("Unsupported integer type"); } } if (in.dtype() == bool_) { diff --git a/tests/gpu_tests.cpp b/tests/gpu_tests.cpp index f0ef969cf..625cbf552 100644 --- a/tests/gpu_tests.cpp +++ b/tests/gpu_tests.cpp @@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") { CHECK_EQ(prod(a, Device::gpu).item(), 1); } + // sum and prod overflow + { + auto a = full({256, 2, 2}, 1u, uint8); + CHECK_EQ(sum(a, Device::gpu).item(), 256 * 4); + CHECK_EQ(prod(a, Device::gpu).item(), 1); + + a = full({65535, 2, 2}, 1u, uint16); + CHECK_EQ(sum(a, Device::gpu).item(), 65535 * 4); + CHECK_EQ(prod(a, Device::gpu).item(), 1); + } +} + +TEST_CASE("test gpu reduce with axes") { // reducing only some axes and irregular layouts { array a(1.0f); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 969bc2ba7..17207efd4 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") { CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item()); } + // Test unsigned sum + { + const int num_elems = 1000; + + auto x = astype(full({num_elems}, 255), uint8); + CHECK_EQ(sum(x, Device::cpu).item(), 255 * num_elems); + + x = astype(full({num_elems}, 65535), uint16); + CHECK_EQ(sum(x, Device::cpu).item(), 65535 * num_elems); + + x = full({3, 3, 3}, 10000, uint32); + CHECK_EQ(sum(x, Device::cpu).item(), 270000); + + x = full({3, 3, 3}, 10000, uint64); + CHECK_EQ(sum(x, Device::cpu).item(), 270000); + } + // Test prod { auto x = array({}); @@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") { CHECK(array_equal(prod(x, 1), array({true, false})).item()); } + // Test unsigned prod + { + auto x = array({255, 255}, {2}, uint8); + CHECK_EQ(prod(x, Device::cpu).item(), 65025); + + x = array({65535, 2}, {2}, uint16); + CHECK_EQ(prod(x, Device::cpu).item(), 131070); + + x = array({100000, 2}, {2}, uint32); + CHECK_EQ(prod(x, Device::cpu).item(), 200000); + + x = array({100000, 2}, {2}, uint64); + CHECK_EQ(prod(x, Device::cpu).item(), 200000); + } + // Test all { auto x = array({});