From 45df803538ff5d5b1ad1b3c1dd928f7001fa7370 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 30 Apr 2025 09:04:12 -0700 Subject: [PATCH] fix integer overflow in qmm --- mlx/backend/metal/kernels/quantized.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index b2b0d8d8f..ba4fb2426 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl( auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl( // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row);