mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
Stop matrix copies with new attention kernel (#1639)
This commit is contained in:
parent
1445dcaa60
commit
9d40e521d7
@ -288,11 +288,9 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
strides[0] == strides[1] * shape[1];
|
||||
};
|
||||
|
||||
// Checks that the last two dims are row contiguous.
|
||||
// Checks that the headdim dimension has stride 1.
|
||||
auto is_matrix_contiguous = [](const array& arr) {
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
return strides[3] == 1 && strides[2] == shape[3];
|
||||
return arr.strides(3) == 1;
|
||||
};
|
||||
|
||||
// We are in vector mode, i.e. a single query
|
||||
|
Loading…
Reference in New Issue
Block a user