Fix matvec vector stride bug (#1168)

This commit is contained in:
Jagrit Digani
2024-05-29 12:18:28 -07:00
committed by GitHub
parent e7a2a3dcd1
commit 9f0df51f8d
2 changed files with 73 additions and 9 deletions

View File

@@ -710,15 +710,19 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@@ -729,12 +733,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
};
auto [a_transposed, a_cols, a] = check_transpose(a_pre);
auto [b_transposed, b_cols, b] = check_transpose(b_pre);
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1);
auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions