mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-23 01:51:18 +08:00)
fix cuda arg reduce
parent a6d780154f
commit c353af5998
@@ -1,5 +1,4 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
@@ -113,7 +112,7 @@ __global__ void arg_reduce_general(
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     T vals[N_READS];
-    auto tid = r * BLOCK_DIM + block.thread_index().z;
+    auto tid = r * BLOCK_DIM + block.thread_index().x;
     cub::LoadDirectBlocked(
         tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
     best = op.reduce_many(best, vals, tid * N_READS);
@@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       constexpr uint32_t N_READS = 4;
       MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
         dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-        dim3 block_dims{1, 1, BLOCK_DIM};
+        dim3 block_dims{BLOCK_DIM, 1, 1};
         auto kernel = &cu::arg_reduce_general<
             InType,
             cu::ArgMax<InType>,
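
Note (not part of the commit): the diff changes the launch geometry from dim3{1, 1, BLOCK_DIM} to dim3{BLOCK_DIM, 1, 1} and the per-thread index from thread_index().z to thread_index().x. CUB's block-wide primitives (e.g. cub::BlockReduce<T, BLOCK_DIM>) take the thread count along x as a template parameter and derive each thread's linear rank from an x-major layout, so the thread dimension the kernel iterates over has to live in x for the linear tid passed to cub::LoadDirectBlocked and any block-wide reduction to line up. Below is a minimal standalone sketch of that pattern under stated assumptions; row_max_kernel, the use of cub::BlockReduce with cub::Max(), and the BLOCK_DIM / N_READS values are illustrative, not MLX code.

// Sketch only: a single-block row max in the same shape as the patched
// kernel. BLOCK_DIM threads along x load N_READS elements each with
// cub::LoadDirectBlocked, keep a per-thread running max, then combine
// with cub::BlockReduce, whose template parameter counts threads in x.
#include <cub/cub.cuh>
#include <cfloat>
#include <cstdio>

constexpr int BLOCK_DIM = 128;
constexpr int N_READS = 4;

__global__ void row_max_kernel(const float* in, float* out, int axis_size) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
  __shared__ typename BlockReduce::TempStorage temp;

  float best = -FLT_MAX;
  int num_tiles = (axis_size + BLOCK_DIM * N_READS - 1) / (BLOCK_DIM * N_READS);
  for (int r = 0; r < num_tiles; ++r) {
    float vals[N_READS];
    // Linear id along x, mirroring the patched line in arg_reduce_general.
    int tid = r * BLOCK_DIM + threadIdx.x;
    cub::LoadDirectBlocked(tid, in, vals, axis_size, -FLT_MAX);
    for (int i = 0; i < N_READS; ++i) {
      best = fmaxf(best, vals[i]);
    }
  }
  // Result is valid in thread 0 only; requires blockDim == {BLOCK_DIM, 1, 1}.
  float row_max = BlockReduce(temp).Reduce(best, cub::Max());
  if (threadIdx.x == 0) {
    out[0] = row_max;
  }
}

int main() {
  const int n = 1000;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) {
    in[i] = static_cast<float>(i % 37);
  }
  // BLOCK_DIM in x, matching both threadIdx.x above and the CUB template.
  row_max_kernel<<<1, dim3{BLOCK_DIM, 1, 1}>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("row max = %f\n", out[0]);  // expect 36
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Launching this sketch with dim3{1, 1, BLOCK_DIM} instead would leave blockDim.x == 1, so every thread's x index is 0 and the block-wide reduction no longer matches its template parameters, which is the mismatch the commit fixes.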