add cuda gemv (#2400)

This commit is contained in:
Awni Hannun
2025-07-22 08:24:13 -07:00
committed by GitHub
parent 1e496ddb82
commit d107d8d495
12 changed files with 198 additions and 21 deletions

View File

@@ -37,7 +37,7 @@ __global__ void copy_gg(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d(
auto [idx_in, idx_out] = elem_to_loc(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}