Add CUDA GEMV (#2400)

This commit is contained in:
Awni Hannun
2025-07-22 08:24:13 -07:00
committed by GitHub
parent 1e496ddb82
commit d107d8d495
12 changed files with 198 additions and 21 deletions

View File

@@ -160,7 +160,7 @@ __global__ void binary_two_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];