2d gather specialization (#1339)

This commit is contained in:
Awni Hannun
2024-08-22 10:48:24 -07:00
committed by GitHub
parent 82db84b899
commit df3233454d
3 changed files with 34 additions and 20 deletions

View File

@@ -13,8 +13,8 @@ constexpr std::string_view gather_kernels = R"(
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
{4}
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {{
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};