mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
2d gather specialization (#1339)
This commit is contained in:
parent
82db84b899
commit
df3233454d
@ -95,11 +95,21 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
slice_size *= s;
|
||||
}
|
||||
|
||||
// Launch 2D grid of threads: indices x slice
|
||||
size_t dim0 = out.size() / slice_size;
|
||||
size_t dim1 = slice_size;
|
||||
auto group_dims = get_block_dims(dim0, dim1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
|
||||
// Launch 3D grid of threads
|
||||
// First two dimensions for the indices, the last one for the slice
|
||||
size_t dim0 = 1;
|
||||
size_t dim1 = 1;
|
||||
if (nidx) {
|
||||
if (inputs[1].ndim() >= 1) {
|
||||
dim0 = inputs[1].shape(0);
|
||||
}
|
||||
if (inputs[1].ndim() >= 2) {
|
||||
dim1 = inputs[1].size() / dim0;
|
||||
}
|
||||
}
|
||||
size_t dim2 = slice_size;
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
|
@ -13,8 +13,8 @@ constexpr std::string_view gather_kernels = R"(
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
{4}
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {{
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
|
||||
|
@ -14,32 +14,36 @@ METAL_FUNC void gather_impl(
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto ind_idx = index.x;
|
||||
auto ind_offset = index.y;
|
||||
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
size_t src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
size_t idx_loc;
|
||||
if (IDX_NDIM == 0) {
|
||||
idx_loc = 0;
|
||||
} else if (IDX_NDIM == 1) {
|
||||
idx_loc = ind_idx * indices.strides[indices.ndim * i];
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
} else {
|
||||
idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc += elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
|
||||
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
|
||||
|
||||
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
|
||||
size_t out_idx = index.z;
|
||||
if (IDX_NDIM == 1) {
|
||||
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
|
||||
} else if (IDX_NDIM >= 2) {
|
||||
out_idx +=
|
||||
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
|
||||
}
|
||||
out[out_idx] = src[src_offset + src_idx];
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user