diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 7397039b5..e24dab819 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -347,7 +347,7 @@ template < MMAFrag_mask_t::load_safe( mfrag, mask, - int(mask_params->M_strides[2]), + int64_t(mask_params->M_strides[2]), Int<1>{}, params->qL, params->kL, diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index 3a2136f4b..ed073b11d 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -346,7 +346,7 @@ template < MSubTile mfrag; mfrag.load_safe( mask, - int(mask_params->M_strides[2]), + int64_t(mask_params->M_strides[2]), Int<1>{}, params->qL, params->kL, diff --git a/mlx/backend/metal/kernels/steel/attn/mma.h b/mlx/backend/metal/kernels/steel/attn/mma.h index db5127c33..b11a111b5 100644 --- a/mlx/backend/metal/kernels/steel/attn/mma.h +++ b/mlx/backend/metal/kernels/steel/attn/mma.h @@ -105,17 +105,20 @@ struct BaseMMAFrag { LimY lim_y, OffX off_x = Int<0>{}, OffY off_y = Int<0>{}) { + src += off_x * str_x + off_y * str_y; STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { if ((off_x + i) < lim_x && (off_y + j) < lim_y) { - dst[i * kElemCols + j] = - static_cast(src[(off_x + i) * str_x + (off_y + j) * str_y]); + dst[i * kElemCols + j] = static_cast(src[0]); } else { dst[i * kElemCols + j] = T(0); } + src += str_y; } + src -= kElemCols * str_y; + src += str_x; } }