From 3e3a4cc78d12a20cc19f2744faa26776b1c759ce Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Tue, 18 Nov 2025 13:58:57 -0800 Subject: [PATCH] Fix for edge checking bug in matmul --- mlx/backend/metal/kernels/steel/gemm/gemm_nax.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h b/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h index 3cd20d7b9..e9b69a200 100644 --- a/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +++ b/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h @@ -69,16 +69,16 @@ auto gemm_loop( if constexpr (kAlignedM) { Atile.load(A + A_offset, params->lda); } else { - const short rmax = transpose_a ? UK : sgp_sm; - const short cmax = transpose_a ? sgp_sm : UK; + const short rmax = transpose_a ? SK : sgp_sm; + const short cmax = transpose_a ? sgp_sm : SK; Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax)); } if constexpr (kAlignedN) { Btile.load(B + B_offset, params->ldb); } else { - const short rmax = transpose_b ? sgp_sn : UK; - const short cmax = transpose_b ? UK : sgp_sn; + const short rmax = transpose_b ? sgp_sn : SK; + const short cmax = transpose_b ? SK : sgp_sn; Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax)); }