mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Scatter optimization : Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code, since 64b integer division is very expensive. Github Issue #506 Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
This commit is contained in:
@@ -212,9 +212,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Make the argument buffer to store the indices for the
|
||||
@@ -317,6 +314,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Cleanup temporaries
|
||||
|
@@ -187,11 +187,11 @@ template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
const device size_t *out_strides [[buffer(8)]],
|
||||
const device size_t& out_ndim [[buffer(9)]],
|
||||
const device int* axes [[buffer(10)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
auto ind_idx = gid / upd_size;
|
||||
auto ind_offset = gid % upd_size;
|
||||
auto ind_idx = gid.y;
|
||||
auto ind_offset = gid.x;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
@@ -208,7 +208,7 @@ template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
|
||||
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ template [[host_name("scatter" name "_" #nindex)]] \
|
||||
const device size_t *out_strides [[buffer(8)]], \
|
||||
const device size_t& out_ndim [[buffer(9)]], \
|
||||
const device int* axes [[buffer(10)]], \
|
||||
uint gid [[thread_position_in_grid]]);
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
// Special case NINDEX=0
|
||||
#define instantiate_scatter_nd0(name, type) \
|
||||
|
Reference in New Issue
Block a user