Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 18:11:15 +08:00)
Fix threadgroup memory in arg reduce (#723)
This commit is contained in:
parent 972d9a3aea
commit 884b4ed43b
@ -11,8 +11,6 @@ template <typename U>
|
|||||||
struct IndexValPair {
|
struct IndexValPair {
|
||||||
uint32_t index;
|
uint32_t index;
|
||||||
U val;
|
U val;
|
||||||
|
|
||||||
IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename U>
|
template <typename U>
|
||||||
@ -65,10 +63,10 @@ struct ArgMax {
|
|||||||
|
|
||||||
template <typename U>
|
template <typename U>
|
||||||
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
||||||
return IndexValPair<U>(
|
return IndexValPair<U>{
|
||||||
simd_shuffle_down(data.index, delta),
|
simd_shuffle_down(data.index, delta),
|
||||||
simd_shuffle_down(data.val, delta)
|
simd_shuffle_down(data.val, delta)
|
||||||
);
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -82,7 +80,6 @@ template <typename T, typename Op, int N_READS>
|
|||||||
const device size_t& ndim [[buffer(5)]],
|
const device size_t& ndim [[buffer(5)]],
|
||||||
const device size_t& axis_stride [[buffer(6)]],
|
const device size_t& axis_stride [[buffer(6)]],
|
||||||
const device size_t& axis_size [[buffer(7)]],
|
const device size_t& axis_size [[buffer(7)]],
|
||||||
threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
|
|
||||||
uint gid [[thread_position_in_grid]],
|
uint gid [[thread_position_in_grid]],
|
||||||
uint lid [[thread_position_in_threadgroup]],
|
uint lid [[thread_position_in_threadgroup]],
|
||||||
uint lsize [[threads_per_threadgroup]],
|
uint lsize [[threads_per_threadgroup]],
|
||||||
@ -111,7 +108,9 @@ template <typename T, typename Op, int N_READS>
|
|||||||
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
|
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
|
||||||
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
|
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
|
||||||
|
|
||||||
IndexValPair<T> best(0, Op::init);
|
IndexValPair<T> best{0, Op::init};
|
||||||
|
|
||||||
|
threadgroup IndexValPair<T> local_data[32];
|
||||||
|
|
||||||
// Loop over the reduction axis in lsize*N_READS buckets
|
// Loop over the reduction axis in lsize*N_READS buckets
|
||||||
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
|
||||||
@ -172,7 +171,6 @@ template <typename T, typename Op, int N_READS>
|
|||||||
const device size_t& ndim [[buffer(5)]], \
|
const device size_t& ndim [[buffer(5)]], \
|
||||||
const device size_t& axis_stride [[buffer(6)]], \
|
const device size_t& axis_stride [[buffer(6)]], \
|
||||||
const device size_t& axis_size [[buffer(7)]], \
|
const device size_t& axis_size [[buffer(7)]], \
|
||||||
threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
|
|
||||||
uint gid [[thread_position_in_grid]], \
|
uint gid [[thread_position_in_grid]], \
|
||||||
uint lid [[thread_position_in_threadgroup]], \
|
uint lid [[thread_position_in_threadgroup]], \
|
||||||
uint lsize [[threads_per_threadgroup]], \
|
uint lsize [[threads_per_threadgroup]], \
|
||||||
|
@ -430,8 +430,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||||
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
|
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
|
||||||
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
|
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
|
||||||
compute_encoder->setThreadgroupMemoryLength(
|
|
||||||
simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
|
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user