mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-24 20:28:16 +08:00
2d gather specialization (#1339)
This commit is contained in:
@@ -13,8 +13,8 @@ constexpr std::string_view gather_kernels = R"(
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
{4}
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {{
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
|
||||
|
Reference in New Issue
Block a user