add cuda gemv (#2400)

This commit is contained in:
Awni Hannun
2025-07-22 08:24:13 -07:00
committed by GitHub
parent 1e496ddb82
commit d107d8d495
12 changed files with 198 additions and 21 deletions

View File

@@ -34,7 +34,7 @@ __global__ void copy_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
}