mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 18:56:39 +08:00
Stop matrix copies with new attention kernel (#1639)
This commit is contained in:
parent
1445dcaa60
commit
9d40e521d7
@ -288,11 +288,9 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
strides[0] == strides[1] * shape[1];
|
strides[0] == strides[1] * shape[1];
|
||||||
};
|
};
|
||||||
|
|
||||||
// Checks that the last two dims are row contiguous.
|
// Checks that the headdim dimension has stride 1.
|
||||||
auto is_matrix_contiguous = [](const array& arr) {
|
auto is_matrix_contiguous = [](const array& arr) {
|
||||||
auto& strides = arr.strides();
|
return arr.strides(3) == 1;
|
||||||
auto& shape = arr.shape();
|
|
||||||
return strides[3] == 1 && strides[2] == shape[3];
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// We are in vector mode ie single query
|
// We are in vector mode ie single query
|
||||||
|
Loading…
Reference in New Issue
Block a user