diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm.h b/mlx/backend/metal/kernels/steel/gemm/gemm.h index 3ce44f941..bbe1d96cc 100644 --- a/mlx/backend/metal/kernels/steel/gemm/gemm.h +++ b/mlx/backend/metal/kernels/steel/gemm/gemm.h @@ -164,10 +164,12 @@ struct GEMMKernel { // Find block in A, B, C const int c_row = tid_y * BM; const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); - A += transpose_a ? c_row : c_row * params->lda; - B += transpose_b ? c_col * params->ldb : c_col; - D += c_row * params->ldd + c_col; + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal index 5989c5602..848ff1b81 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal @@ -100,12 +100,14 @@ template < // Find block in A, B, C const int c_row = tid_y * BM; const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); - A += transpose_a ? c_row : c_row * params->lda; - B += transpose_b ? c_col * params->ldb : c_col; - D += c_row * params->ldd + c_col; + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; - C += c_row * addmm_params->ldc + c_col * addmm_params->fdc; + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); @@ -164,7 +166,7 @@ template < else { // Loop over K - unaligned case short tgp_bm = min(BM, params->M - c_row); short tgp_bn = min(BN, params->N - c_col); - short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + int leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; if (tgp_bm == BM && tgp_bn == BN) { gemm_kernel::gemm_loop( diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal index 24710f5fd..586595782 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal @@ -114,10 +114,12 @@ block_masked_gemm( // Find block in A, B, C const int c_row = tid_y * BM; const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); - A += transpose_a ? c_row : c_row * params->lda; - B += transpose_b ? c_col * params->ldb : c_col; - D += c_row * params->ldd + c_col; + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal index f99149569..e5b279f4e 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal @@ -65,12 +65,16 @@ template < const int c_col = tid_x * BN; const int k_start = params->split_k_partition_size * tid_z; - A += transpose_a ? (c_row + k_start * params->lda) - : (k_start + c_row * params->lda); - B += transpose_b ? (k_start + c_col * params->ldb) - : (c_col + k_start * params->ldb); - C += (params->split_k_partition_stride * tid_z) + - (c_row * params->ldc + c_col); + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + const size_t k_start_long = size_t(k_start); + + A += transpose_a ? (c_row_long + k_start_long * params->lda) + : (k_start_long + c_row_long * params->lda); + B += transpose_b ? (k_start_long + c_col_long * params->ldb) + : (c_col_long + k_start_long * params->ldb); + C += (size_t(params->split_k_partition_stride) * tid_z) + + (c_row_long * params->ldc + c_col_long); // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); @@ -249,10 +253,10 @@ template < const constant int& ldd [[buffer(4)]], uint2 gid [[thread_position_in_grid]]) { // Ajust D and C - D += gid.x + gid.y * ldd; - C_split += gid.x + gid.y * ldd; + D += gid.x + gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); - int offset = 0; + size_t offset = 0; AccT out = 0; for (int i = 0; i < k_partitions; i++) { @@ -281,11 +285,11 @@ template < const constant float& beta [[buffer(9)]], uint2 gid [[thread_position_in_grid]]) { // Ajust D and C - C += gid.x * fdc + gid.y * ldc; - D += gid.x + gid.y * ldd; - C_split += gid.x + gid.y * ldd; + C += gid.x * size_t(fdc) + gid.y * size_t(ldc); + D += gid.x + gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); - int offset = 0; + size_t offset = 0; AccT out = 0; for (int i = 0; i < k_partitions; i++) {