Update block offset adjustment to be in size_t (#1087)

This commit is contained in:
Jagrit Digani 2024-05-08 08:10:23 -07:00 committed by GitHub
parent 9814a2ae12
commit fe96ceee66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 34 additions and 24 deletions

View File

@ -164,10 +164,12 @@ struct GEMMKernel {
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);

View File

@ -100,12 +100,14 @@ template <
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
@ -164,7 +166,7 @@ template <
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
int leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(

View File

@ -114,10 +114,12 @@ block_masked_gemm(
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];

View File

@ -65,12 +65,16 @@ template <
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
A += transpose_a ? (c_row + k_start * params->lda)
: (k_start + c_row * params->lda);
B += transpose_b ? (k_start + c_col * params->ldb)
: (c_col + k_start * params->ldb);
C += (params->split_k_partition_stride * tid_z) +
(c_row * params->ldc + c_col);
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
const size_t k_start_long = size_t(k_start);
A += transpose_a ? (c_row_long + k_start_long * params->lda)
: (k_start_long + c_row_long * params->lda);
B += transpose_b ? (k_start_long + c_col_long * params->ldb)
: (c_col_long + k_start_long * params->ldb);
C += (size_t(params->split_k_partition_stride) * tid_z) +
(c_row_long * params->ldc + c_col_long);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
@ -249,10 +253,10 @@ template <
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
int offset = 0;
size_t offset = 0;
AccT out = 0;
for (int i = 0; i < k_partitions; i++) {
@ -281,11 +285,11 @@ template <
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
C += gid.x * fdc + gid.y * ldc;
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
int offset = 0;
size_t offset = 0;
AccT out = 0;
for (int i = 0; i < k_partitions; i++) {