mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix attention for large sizes (#2903)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
This commit is contained in:
@@ -347,7 +347,7 @@ template <
|
||||
MMAFrag_mask_t::load_safe(
|
||||
mfrag,
|
||||
mask,
|
||||
int(mask_params->M_strides[2]),
|
||||
int64_t(mask_params->M_strides[2]),
|
||||
Int<1>{},
|
||||
params->qL,
|
||||
params->kL,
|
||||
|
||||
@@ -346,7 +346,7 @@ template <
|
||||
MSubTile mfrag;
|
||||
mfrag.load_safe(
|
||||
mask,
|
||||
int(mask_params->M_strides[2]),
|
||||
int64_t(mask_params->M_strides[2]),
|
||||
Int<1>{},
|
||||
params->qL,
|
||||
params->kL,
|
||||
|
||||
@@ -105,17 +105,20 @@ struct BaseMMAFrag<T, 8, 8> {
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
src += off_x * str_x + off_y * str_y;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[i * kElemCols + j] =
|
||||
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
|
||||
dst[i * kElemCols + j] = static_cast<T>(src[0]);
|
||||
} else {
|
||||
dst[i * kElemCols + j] = T(0);
|
||||
}
|
||||
src += str_y;
|
||||
}
|
||||
src -= kElemCols * str_y;
|
||||
src += str_x;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user