Mirror of https://github.com/ml-explore/mlx.git
fix large arg reduce (#2206)
This commit is contained in:
parent 0359bf02c9
commit eebe73001a
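What appears to go wrong for large reductions, judging by the hunks below: the host code used to launch a flat 1D grid of out.size() * thread_group_size threads, and the kernel recovered its output row as gid / lsize from a 32-bit thread_position_in_grid. Once the total thread count passes what 32 bits can index, rows beyond that point become unreachable. A rough, self-contained C++ check of that ceiling (the 1024-thread group size is an assumed typical value, not taken from the commit):

// Hypothetical back-of-the-envelope check, not MLX code.
#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t thread_group_size = 1024;            // assumed typical group size
  const uint64_t flat_positions = uint64_t(1) << 32;  // what a 32-bit `uint gid` can index
  const uint64_t max_rows = flat_positions / thread_group_size;
  std::printf("rows addressable via gid / lsize: %llu\n",
              (unsigned long long)max_rows);          // 4194304, i.e. ~4.2M rows
  return 0;
}

The fix below instead dispatches one threadgroup per output element across the grid's y and z axes and rebuilds the row index from gid.y and gid.z, so no single dimension has to hold the whole count.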
@@ -80,9 +80,10 @@ template <typename T, typename Op, int N_READS = 4>
     const constant size_t& ndim [[buffer(5)]],
     const constant int64_t& axis_stride [[buffer(6)]],
     const constant size_t& axis_size [[buffer(7)]],
-    uint gid [[thread_position_in_grid]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint lsize [[threads_per_threadgroup]],
+    uint3 gid [[thread_position_in_grid]],
+    uint3 gsize [[threads_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]],
     uint simd_size [[threads_per_simdgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
@@ -104,17 +105,18 @@ template <typename T, typename Op, int N_READS = 4>
 
   // Compute the input/output index. There is one beginning and one output for
   // the whole threadgroup.
-  auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
-  auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
+  int64_t row_idx = gid.y + static_cast<int64_t>(gsize.y) * gid.z;
+  auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim);
+  auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim);
 
   IndexValPair<T> best{0, Op::init};
 
   threadgroup IndexValPair<T> local_data[32];
 
   // Loop over the reduction axis in lsize*N_READS buckets
-  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
+  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
     // Read the current value
-    uint32_t current_index = r * lsize * N_READS + lid * N_READS;
+    uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS;
     uint32_t offset = current_index;
     const device T* current_in = in + in_idx + current_index * axis_stride;
     T vals[N_READS];
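The kernel-side half of the change is the row_idx computation above: the output row is now recovered from the threadgroup's y/z grid position instead of the flat thread index. A minimal plain-C++ sketch of that formula (the struct and the values are stand-ins, and the numbers are deliberately oversized to show why the arithmetic is widened to 64 bits before the multiply):

// Plain C++ sketch of the row-index recovery, not Metal code; values are made up.
#include <cstdint>
#include <cstdio>

struct u3 { uint32_t x, y, z; };  // stand-in for Metal's uint3

int main() {
  u3 gsize{1024, 65535, 80000};  // x: threads per group, y/z: output-row grid
  u3 gid{0, 100, 70000};         // this threadgroup's position in the grid

  uint32_t wrapped = gid.y + gsize.y * gid.z;  // 32-bit product wraps past UINT32_MAX
  int64_t row_idx = gid.y + static_cast<int64_t>(gsize.y) * gid.z;  // widened first

  std::printf("wrapped:  %u\n", wrapped);              // 292482804 (wrong)
  std::printf("row_idx: %lld\n", (long long)row_idx);  // 4587450100 (correct)
  return 0;
}

Within a row nothing changes: lsize.x and lid.x play exactly the roles the scalar lsize and lid played before.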
@@ -144,7 +146,7 @@ template <typename T, typename Op, int N_READS = 4>
   }
 
   // Read the appropriate value from local data and perform one simd reduction
-  uint simd_groups = ceildiv(lsize, simd_size);
+  uint simd_groups = ceildiv(lsize.x, simd_size);
   if (simd_lane_id < simd_groups) {
     best = local_data[simd_lane_id];
   }
@@ -154,7 +156,7 @@ template <typename T, typename Op, int N_READS = 4>
   }
 
   // Finally write the output
-  if (lid == 0) {
+  if (lid.x == 0) {
     out[out_idx] = best.index;
   }
 }
@@ -182,8 +182,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
         (thread_group_size + simd_size - 1) / simd_size * simd_size;
     assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
 
-    size_t n_threads = out.size() * thread_group_size;
-    MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
+    auto gd = get_2d_grid_dims(out.shape(), out.strides());
+    MTL::Size grid_dims = MTL::Size(thread_group_size, gd.width, gd.height);
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     compute_encoder.set_compute_pipeline_state(kernel);
     compute_encoder.set_input_array(in, 0);
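On the host side, the flat n_threads launch is replaced by a grid whose x extent is the threadgroup and whose y/z extents come from get_2d_grid_dims(out.shape(), out.strides()). The sketch below is only a guess at the kind of split such a helper has to produce, with a made-up name and an assumed per-dimension limit; it is not the MLX implementation:

// Hypothetical sketch: split `rows` output elements into width x height so that
// neither dimension exceeds a per-dimension grid limit (limit value assumed).
#include <cstdint>
#include <cstdio>

struct GridDims { uint64_t width, height; };

GridDims split_rows_2d(uint64_t rows, uint64_t max_dim = 65535) {
  uint64_t height = (rows + max_dim - 1) / max_dim;  // ceil(rows / max_dim)
  uint64_t width = (rows + height - 1) / height;     // ceil(rows / height), <= max_dim
  return {width, height};                            // width * height >= rows
}

int main() {
  GridDims gd = split_rows_2d(10'000'000);  // e.g. ten million output rows
  std::printf("width=%llu height=%llu covers %llu slots\n",
              (unsigned long long)gd.width, (unsigned long long)gd.height,
              (unsigned long long)(gd.width * gd.height));
  return 0;
}

Pairing such a split with group_dims = (thread_group_size, 1, 1) yields one threadgroup per (y, z) cell, which is exactly what the kernel's row_idx = gid.y + gsize.y * gid.z indexing assumes.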