diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 7ba6f5c05..6288025e4 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -95,11 +95,21 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { slice_size *= s; } - // Launch 2D grid of threads: indices x slice - size_t dim0 = out.size() / slice_size; - size_t dim1 = slice_size; - auto group_dims = get_block_dims(dim0, dim1, 1); - MTL::Size grid_dims = MTL::Size(dim0, dim1, 1); + // Launch 3D grid of threads + // First two dimensions for the indices, the last one for the slice + size_t dim0 = 1; + size_t dim1 = 1; + if (nidx) { + if (inputs[1].ndim() >= 1) { + dim0 = inputs[1].shape(0); + } + if (inputs[1].ndim() >= 2) { + dim1 = inputs[1].size() / dim0; + } + } + size_t dim2 = slice_size; + auto group_dims = get_block_dims(dim0, dim1, dim2); + MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2); // Collect all idx shapes and strides into one place std::vector idx_shapes; diff --git a/mlx/backend/metal/jit/indexing.h b/mlx/backend/metal/jit/indexing.h index 2227fa2f1..9c5ec6213 100644 --- a/mlx/backend/metal/jit/indexing.h +++ b/mlx/backend/metal/jit/indexing.h @@ -13,8 +13,8 @@ constexpr std::string_view gather_kernels = R"( const constant size_t* idx_strides [[buffer(8)]], const constant int& idx_ndim [[buffer(9)]], {4} - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) {{ + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) {{ Indices<{2}, {3}> idxs{{ {{ {5} }}, idx_shapes, idx_strides, idx_ndim}}; diff --git a/mlx/backend/metal/kernels/gather.h b/mlx/backend/metal/kernels/gather.h index 34f807f3d..4ee529974 100644 --- a/mlx/backend/metal/kernels/gather.h +++ b/mlx/backend/metal/kernels/gather.h @@ -14,32 +14,36 @@ METAL_FUNC void gather_impl( const constant int* slice_sizes [[buffer(5)]], const constant int* axes [[buffer(6)]], const thread Indices& indices, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto ind_idx = index.x; - auto ind_offset = index.y; - + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { size_t src_idx = 0; for (int i = 0; i < NIDX; ++i) { size_t idx_loc; if (IDX_NDIM == 0) { idx_loc = 0; } else if (IDX_NDIM == 1) { - idx_loc = ind_idx * indices.strides[indices.ndim * i]; + idx_loc = index.x * indices.strides[indices.ndim * i]; } else { - idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); + idx_loc = index.x * indices.strides[indices.ndim * i]; + idx_loc += elem_to_loc( + index.y, + &indices.shapes[indices.ndim * i + 1], + &indices.strides[indices.ndim * i + 1], + indices.ndim - 1); } auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); src_idx += idx_val * src_strides[ax]; } - auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim); + auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); - size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; + size_t out_idx = index.z; + if (IDX_NDIM == 1) { + out_idx += static_cast(grid_dim.z) * index.x; + } else if (IDX_NDIM >= 2) { + out_idx += + grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); + } out[out_idx] = src[src_offset + src_idx]; }