From da691257ecec28385506e594f25c716e9e36119c Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Fri, 30 Aug 2024 13:32:41 -0700 Subject: [PATCH] Fix overflow in quantize/dequantize (#1379) * add 2d indices to prevent overflow * use nthreads not out size --- mlx/backend/metal/kernels/quantized.h | 31 ++++++++++++++++----------- mlx/backend/metal/quantized.cpp | 13 ++++++++++- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index f4d750e58..4f388b9f3 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1460,7 +1460,8 @@ template device uint8_t* out [[buffer(1)]], device T* scales [[buffer(2)]], device T* biases [[buffer(3)]], - uint index [[thread_position_in_grid]]) { + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { constexpr T eps = T(1e-7); constexpr int simd_size = 32; constexpr int uint8_bits = 8; @@ -1475,8 +1476,9 @@ template group_size % simd_size == 0, "Group size must be divisible by simd size."); - int in_index = index * values_per_reduce; - int out_index = index * writes_per_pack; + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t in_index = offset * values_per_reduce; + size_t out_index = offset * writes_per_pack; T w_thread[values_per_reduce]; T w_min = Limits::max; @@ -1503,7 +1505,7 @@ template T bias = at_zero ? T(0) : edge; // Write out the scales and biases - int gindex = in_index / group_size; + size_t gindex = in_index / group_size; if (in_index % group_size == 0) { scales[gindex] = scale; biases[gindex] = bias; @@ -1542,13 +1544,16 @@ template const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], device uint8_t* out [[buffer(3)]], - uint index [[thread_position_in_grid]]) { + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { constexpr int uint8_bits = 8; constexpr int packs_per_int = uint8_bits / bits; constexpr T n_bins = (1 << bits) - 1; - int in_index = index * packs_per_int; - int gindex = in_index / group_size; + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t in_index = offset * packs_per_int; + size_t gindex = in_index / group_size; + T scale = scales[gindex]; T bias = biases[gindex]; @@ -1562,7 +1567,7 @@ template output += val << (bits * i); } } - out[index] = output; + out[offset] = output; } template @@ -1571,15 +1576,17 @@ template const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], device T* out [[buffer(3)]], - uint index [[thread_position_in_grid]]) { + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { constexpr int uint8_bits = 8; constexpr int packs_per_int = uint8_bits / bits; - int oindex = index * packs_per_int; - int gindex = oindex / group_size; + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t oindex = offset * packs_per_int; + size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; - uint val = w[index]; + uint val = w[offset]; #pragma clang loop unroll(full) for (int i = 0; i < packs_per_int; i++) { diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 17dbd02d1..bd5c19fe3 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -584,8 +584,19 @@ void fast::AffineQuantize::eval_gpu( dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread; NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } auto group_dims = MTL::Size(thread_group_size, 1, 1); - auto grid_dims = MTL::Size(nthreads, 1, 1); + bool use_2d = nthreads > UINT_MAX; + auto grid_shape = w.shape(); + if (dequantize_) { + grid_shape.back() *= uint8_per_uint32; + } else { + grid_shape.back() /= per_thread; + } + MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides()) + : MTL::Size(nthreads, 1, 1); compute_encoder.dispatchThreads(grid_dims, group_dims); d.get_command_buffer(s.index)->addCompletedHandler(