add cuda gemv (#2400)

This commit is contained in:
Awni Hannun
2025-07-22 08:24:13 -07:00
committed by GitHub
parent 1e496ddb82
commit d107d8d495
12 changed files with 198 additions and 21 deletions

View File

@@ -76,7 +76,7 @@ __global__ void ternary_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
auto [a_idx, b_idx, c_idx] = elem_to_loc(
index,
shape.data(),
a_strides.data(),