Fix threadgroup memory in arg reduce (#723)

This commit is contained in:
Jagrit Digani 2024-02-21 19:42:16 -08:00 committed by GitHub
parent 972d9a3aea
commit 884b4ed43b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 9 deletions

View File

@ -11,8 +11,6 @@ template <typename U>
struct IndexValPair {
uint32_t index;
U val;
IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
};
template <typename U>
@ -65,10 +63,10 @@ struct ArgMax {
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>(
return IndexValPair<U>{
simd_shuffle_down(data.index, delta),
simd_shuffle_down(data.val, delta)
);
};
}
@ -82,7 +80,6 @@ template <typename T, typename Op, int N_READS>
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@ -111,7 +108,9 @@ template <typename T, typename Op, int N_READS>
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
IndexValPair<T> best(0, Op::init);
IndexValPair<T> best{0, Op::init};
threadgroup IndexValPair<T> local_data[32];
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
@ -172,7 +171,6 @@ template <typename T, typename Op, int N_READS>
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \

View File

@ -430,8 +430,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder->setThreadgroupMemoryLength(
simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}