mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix threadgroup memory in arg reduce (#723)
This commit is contained in:
parent
972d9a3aea
commit
884b4ed43b
@ -11,8 +11,6 @@ template <typename U>
|
||||
struct IndexValPair {
|
||||
uint32_t index;
|
||||
U val;
|
||||
|
||||
IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
@ -65,10 +63,10 @@ struct ArgMax {
|
||||
|
||||
template <typename U>
|
||||
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
||||
return IndexValPair<U>(
|
||||
return IndexValPair<U>{
|
||||
simd_shuffle_down(data.index, delta),
|
||||
simd_shuffle_down(data.val, delta)
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +80,6 @@ template <typename T, typename Op, int N_READS>
|
||||
const device size_t& ndim [[buffer(5)]],
|
||||
const device size_t& axis_stride [[buffer(6)]],
|
||||
const device size_t& axis_size [[buffer(7)]],
|
||||
threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
@ -111,7 +108,9 @@ template <typename T, typename Op, int N_READS>
|
||||
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
|
||||
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
|
||||
|
||||
IndexValPair<T> best(0, Op::init);
|
||||
IndexValPair<T> best{0, Op::init};
|
||||
|
||||
threadgroup IndexValPair<T> local_data[32];
|
||||
|
||||
// Loop over the reduction axis in lsize*N_READS buckets
|
||||
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
||||
@ -172,7 +171,6 @@ template <typename T, typename Op, int N_READS>
|
||||
const device size_t& ndim [[buffer(5)]], \
|
||||
const device size_t& axis_stride [[buffer(6)]], \
|
||||
const device size_t& axis_size [[buffer(7)]], \
|
||||
threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
|
@ -430,8 +430,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
|
||||
compute_encoder->setThreadgroupMemoryLength(
|
||||
simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user