mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
fix scatter (#821)
This commit is contained in:
@@ -201,15 +201,12 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
|
||||
if (index_nd1_specialization) {
|
||||
-bool upd_col_contiguous = upd.flags().col_contiguous;
|
||||
compute_encoder->setBytes(
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
|
||||
-compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
|
@@ -20,7 +20,6 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
-const constant bool& upd_col_contiguous [[buffer(6)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
@@ -33,11 +32,7 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
-if (!upd_col_contiguous) {
|
||||
-op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
||||
-} else {
|
||||
-op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
|
||||
-}
|
||||
+op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
||||
}
|
||||
|
||||
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
|
||||
@@ -48,7 +43,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
-const constant bool& upd_col_contiguous [[buffer(6)]], \
|
||||
IDX_ARG(IdxT) \
|
||||
uint2 gid [[thread_position_in_grid]]) { \
|
||||
\
|
||||
@@ -60,7 +54,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
out_shape, \
|
||||
out_strides, \
|
||||
upd_size, \
|
||||
-upd_col_contiguous, \
|
||||
idx_buffers, \
|
||||
gid); \
|
||||
\
|
||||
@@ -195,7 +188,6 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
-const constant bool& upd_col_contiguous [[buffer(6)]], \
|
||||
IDX_ARG(idx_t) \
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
|
Reference in New Issue
Block a user