// Copyright © 2023 Apple Inc.

#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

template <typename U>
struct IndexValPair {
  uint32_t index;
  U val;
};

template <typename U>
struct ArgMin {
  static constexpr constant U init = Limits<U>::max;

  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
    if (best.val > current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  IndexValPair<U>
  reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] < best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename U>
struct ArgMax {
  static constexpr constant U init = Limits<U>::min;

  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
    if (best.val < current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  IndexValPair<U>
  reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] > best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
  return IndexValPair<U>{
      simd_shuffle_down(data.index, delta),
      simd_shuffle_down(data.val, delta)};
}

template <typename T, typename Op, int N_READS = 4>
[[kernel]] void arg_reduce_general(
    const device T* in [[buffer(0)]],
    device uint32_t* out [[buffer(1)]],
    const constant int* shape [[buffer(2)]],
    const constant int64_t* in_strides [[buffer(3)]],
    const constant int64_t* out_strides [[buffer(4)]],
    const constant size_t& ndim [[buffer(5)]],
    const constant int64_t& axis_stride [[buffer(6)]],
    const constant size_t& axis_size [[buffer(7)]],
    uint3 gid [[thread_position_in_grid]],
    uint3 gsize [[threads_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Shapes and strides *do not* contain the reduction axis. The reduction
  // size and stride are provided in axis_size and axis_stride.
  //
  // Note: with this convention, in shape == out shape.
  //
  // The sketch of the kernel is as follows:
  //    1. Launch prod(shape) * thread_group_size threads.
  //    2. Loop ceildiv(axis_size, N_READS * lsize) times
  //    3. Read N_READS input values per thread
  //    4. Reduce among them and go to 3
  //    5. Reduce within each simd_group
  //    6. Write the per-simdgroup results to threadgroup memory
  //    7. Reduce them across the threadgroup
  //    8. Write the output without the need for atomics
  Op op;

  // Compute the input/output index. There is one starting location and one
  // output for the whole threadgroup.
  int64_t row_idx = gid.y + static_cast<int64_t>(gsize.y) * gid.z;
  auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim);
  auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim);

  IndexValPair<T> best{0, Op::init};

  // One slot per simdgroup; 32 slots cover the maximum of 1024 threads per
  // threadgroup with 32-wide simdgroups.
  threadgroup IndexValPair<T> local_data[32];

  // Loop over the reduction axis in lsize*N_READS buckets
  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
    // Read the current values
    uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS;
    uint32_t offset = current_index;
    const device T* current_in = in + in_idx + current_index * axis_stride;
    T vals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
      current_index++;
      current_in += axis_stride;
    }
    best = op.template reduce_many<N_READS>(best, vals, offset);
  }
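
  // The cross-thread reduction below proceeds in two stages: a shuffle-based
  // tree reduction within each simdgroup, then one more simdgroup reduction
  // over the per-simdgroup winners staged in threadgroup memory. Halving the
  // shuffle offset each step (16, 8, 4, 2, 1 for a 32-wide simdgroup) leaves
  // lane 0 holding the reduction over all lanes.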
  // Each thread now holds its local best value and index, so we reduce
  // across the threadgroup.

  // First, the per-simdgroup reduction.
  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
    IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
    best = op.reduce(best, neighbor);
  }

  // Write to the threadgroup memory
  if (simd_lane_id == 0) {
    local_data[simd_group_id] = best;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id != 0) {
    return;
  }

  // Read the appropriate value from local data and perform one more simd
  // reduction
  uint simd_groups = ceildiv(lsize.x, simd_size);
  if (simd_lane_id < simd_groups) {
    best = local_data[simd_lane_id];
  }
  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
    IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
    best = op.reduce(best, neighbor);
  }

  // Finally write the output
  if (lid.x == 0) {
    out[out_idx] = best.index;
  }
}

// clang-format off
#define instantiate_arg_reduce(name, itype)                      \
  instantiate_kernel(                                            \
      "argmin_" #name, arg_reduce_general, itype, ArgMin<itype>) \
  instantiate_kernel(                                            \
      "argmax_" #name, arg_reduce_general, itype, ArgMax<itype>)

instantiate_arg_reduce(bool_, bool)
instantiate_arg_reduce(uint8, uint8_t)
instantiate_arg_reduce(uint16, uint16_t)
instantiate_arg_reduce(uint32, uint32_t)
instantiate_arg_reduce(uint64, uint64_t)
instantiate_arg_reduce(int8, int8_t)
instantiate_arg_reduce(int16, int16_t)
instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t)
// clang-format on
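
// For reference, a sketch of what one instantiation produces, assuming the
// instantiate_kernel helper (from utils.h) pairs a host-visible name with an
// explicit template instantiation; this illustrates the pattern rather than
// the macro's verbatim expansion:
//
//   template [[host_name("argmin_float32")]] [[kernel]]
//   decltype(arg_reduce_general<float, ArgMin<float>>)
//   arg_reduce_general<float, ArgMin<float>>;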
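
// Host-side note, inferred from the indexing logic above rather than stated
// in this file: lid.x/lsize.x span the reduction axis while gid.y and gid.z
// together enumerate the prod(shape) output rows, so the host would dispatch
// one threadgroup per output element, e.g. (hypothetical sizes):
//
//   MTL::Size group_dims(thread_group_size, 1, 1);
//   MTL::Size grid_dims(thread_group_size, n_rows, 1);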