mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Update block offset adjustment to be in size_t (#1087)
This commit is contained in:
parent
9814a2ae12
commit
fe96ceee66
@ -164,10 +164,12 @@ struct GEMMKernel {
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
|
@ -100,12 +100,14 @@ template <
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
|
||||
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
@ -164,7 +166,7 @@ template <
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
int leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
|
@ -114,10 +114,12 @@ block_masked_gemm(
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
|
||||
|
||||
|
@ -65,12 +65,16 @@ template <
|
||||
const int c_col = tid_x * BN;
|
||||
const int k_start = params->split_k_partition_size * tid_z;
|
||||
|
||||
A += transpose_a ? (c_row + k_start * params->lda)
|
||||
: (k_start + c_row * params->lda);
|
||||
B += transpose_b ? (k_start + c_col * params->ldb)
|
||||
: (c_col + k_start * params->ldb);
|
||||
C += (params->split_k_partition_stride * tid_z) +
|
||||
(c_row * params->ldc + c_col);
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
const size_t k_start_long = size_t(k_start);
|
||||
|
||||
A += transpose_a ? (c_row_long + k_start_long * params->lda)
|
||||
: (k_start_long + c_row_long * params->lda);
|
||||
B += transpose_b ? (k_start_long + c_col_long * params->ldb)
|
||||
: (c_col_long + k_start_long * params->ldb);
|
||||
C += (size_t(params->split_k_partition_stride) * tid_z) +
|
||||
(c_row_long * params->ldc + c_col_long);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
@ -249,10 +253,10 @@ template <
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
// Ajust D and C
|
||||
D += gid.x + gid.y * ldd;
|
||||
C_split += gid.x + gid.y * ldd;
|
||||
D += gid.x + gid.y * size_t(ldd);
|
||||
C_split += gid.x + gid.y * size_t(ldd);
|
||||
|
||||
int offset = 0;
|
||||
size_t offset = 0;
|
||||
AccT out = 0;
|
||||
|
||||
for (int i = 0; i < k_partitions; i++) {
|
||||
@ -281,11 +285,11 @@ template <
|
||||
const constant float& beta [[buffer(9)]],
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
// Ajust D and C
|
||||
C += gid.x * fdc + gid.y * ldc;
|
||||
D += gid.x + gid.y * ldd;
|
||||
C_split += gid.x + gid.y * ldd;
|
||||
C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
|
||||
D += gid.x + gid.y * size_t(ldd);
|
||||
C_split += gid.x + gid.y * size_t(ldd);
|
||||
|
||||
int offset = 0;
|
||||
size_t offset = 0;
|
||||
AccT out = 0;
|
||||
|
||||
for (int i = 0; i < k_partitions; i++) {
|
||||
|
Loading…
Reference in New Issue
Block a user