Fix attention for large sizes (#2903)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled

This commit is contained in:
Awni Hannun
2025-12-13 06:54:30 -08:00
committed by GitHub
parent bedefed784
commit 47d2505ea9
3 changed files with 7 additions and 4 deletions

View File

@@ -347,7 +347,7 @@ template <
MMAFrag_mask_t::load_safe(
mfrag,
mask,
int(mask_params->M_strides[2]),
int64_t(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,

View File

@@ -346,7 +346,7 @@ template <
MSubTile mfrag;
mfrag.load_safe(
mask,
int(mask_params->M_strides[2]),
int64_t(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,

View File

@@ -105,17 +105,20 @@ struct BaseMMAFrag<T, 8, 8> {
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
src += off_x * str_x + off_y * str_y;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
dst[i * kElemCols + j] = static_cast<T>(src[0]);
} else {
dst[i * kElemCols + j] = T(0);
}
src += str_y;
}
src -= kElemCols * str_y;
src += str_x;
}
}