Stop matrix copies with new attention kernel (#1639)

This commit is contained in:
Jagrit Digani 2024-12-02 14:12:38 -08:00 committed by GitHub
parent 1445dcaa60
commit 9d40e521d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -288,11 +288,9 @@ void ScaledDotProductAttention::eval_gpu(
strides[0] == strides[1] * shape[1]; strides[0] == strides[1] * shape[1];
}; };
// Checks that the last two dims are row contiguous. // Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) { auto is_matrix_contiguous = [](const array& arr) {
auto& strides = arr.strides(); return arr.strides(3) == 1;
auto& shape = arr.shape();
return strides[3] == 1 && strides[2] == shape[3];
}; };
// We are in vector mode ie single query // We are in vector mode ie single query