fix scatter + test (#1202)

* fix scatter + test

* fix test warnings

* fix metal validation
This commit is contained in:
Awni Hannun
2024-06-11 14:35:12 -07:00
committed by GitHub
parent 709ccc6800
commit df964132fb
5 changed files with 54 additions and 9 deletions

View File

@@ -293,7 +293,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
if (upd_ndim <= 1) {
// Placeholder so Metal doesn't compalain
int shape_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 6);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
// Set index buffers
for (int i = 0; i < nidx; ++i) {

View File

@@ -38,12 +38,24 @@ constexpr std::string_view scatter_kernels = R"(
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
updates,
out,
out_shape,
out_strides,
out_ndim,
upd_shape,
upd_ndim,
upd_size,
idx_buffers,
gid);
}}
[[kernel]] void scatter{0}_{4}(

View File

@@ -10,7 +10,10 @@ METAL_FUNC void scatter_1d_index_impl(
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
@@ -21,7 +24,14 @@ METAL_FUNC void scatter_1d_index_impl(
out_idx += idx_val * out_strides[i];
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
if (upd_ndim > 1) {
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
out_idx += out_offset;
} else {
out_idx += gid.x;
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
}
template <typename T, typename IdxT, typename Op, int NIDX>