mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
improvements to scatter / gather (#1541)
This commit is contained in:
@@ -25,11 +25,13 @@ METAL_FUNC void gather_impl(
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
} else {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc += elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
idx_loc += indices.row_contiguous[i]
|
||||
? index.y
|
||||
: elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
|
||||
@@ -9,6 +9,7 @@ struct Indices {
|
||||
// One device pointer per index array (NIDX of them); kernels read
// `buffers[i][loc]` to fetch the index value for index array `i`.
const array<const device IdxT*, NIDX> buffers;
|
||||
// Shapes of the index arrays, stored back to back: the entries for index
// array `i` begin at `shapes[ndim * i]`.
const constant int* shapes;
|
||||
// Strides of the index arrays, laid out the same way as `shapes`
// (entries for index array `i` begin at `strides[ndim * i]`).
const constant size_t* strides;
|
||||
// One flag per index array. Where `row_contiguous[i]` is true the kernels
// use the flat element position directly as the location in `buffers[i]`
// instead of calling elem_to_loc with the shape/stride entries.
const constant bool* row_contiguous;
|
||||
// Number of shape/stride entries per index array; also the dimension count
// passed to elem_to_loc when resolving a location in an index array.
const int ndim;
|
||||
};
|
||||
|
||||
|
||||
@@ -4,73 +4,54 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
METAL_FUNC void scatter_1d_index_impl(
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; i++) {
|
||||
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
if (upd_ndim > 1) {
|
||||
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
} else {
|
||||
out_idx += gid.x;
|
||||
}
|
||||
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NIDX,
|
||||
bool UPD_ROW_CONTIG,
|
||||
int NWORK>
|
||||
METAL_FUNC void scatter_impl(
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const device T* updates,
|
||||
device mlx_atomic<T>* out,
|
||||
const constant int* upd_shape,
|
||||
const constant size_t* upd_strides,
|
||||
const constant size_t& upd_ndim,
|
||||
const constant size_t& upd_size,
|
||||
const constant int* out_shape,
|
||||
const constant size_t* out_strides,
|
||||
const constant size_t& out_ndim,
|
||||
const constant int* axes,
|
||||
const constant size_t& idx_size,
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
auto ind_idx = gid.y;
|
||||
auto ind_offset = gid.x;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
|
||||
auto ind_idx = gid.y * NWORK;
|
||||
size_t out_offset = 0;
|
||||
if (upd_size > 1) {
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
out_offset =
|
||||
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
}
|
||||
|
||||
auto upd_idx =
|
||||
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
|
||||
size_t out_idx = out_offset;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = indices.row_contiguous[i]
|
||||
? ind_idx
|
||||
: elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
auto upd_idx = ind_idx * upd_size + gid.x;
|
||||
if constexpr (!UPD_ROW_CONTIG) {
|
||||
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
|
||||
}
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user