Change name of argument

This commit is contained in:
Angelos Katharopoulos
2025-06-29 04:11:09 -07:00
parent ef813b6d13
commit aa7841814c

View File

@@ -156,22 +156,15 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4;
dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim_constant) {
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{block_dim_constant(), 1, 1};
auto kernel = cu::arg_reduce_general<
T,
cu::ArgMax<T>,
block_dim_constant(),
N_READS>;
auto kernel =
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = cu::arg_reduce_general<
T,
cu::ArgMin<T>,
block_dim_constant(),
N_READS>;
kernel = cu::
arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
kernel<<<num_blocks, block_dim(), 0, stream>>>(
in.data<T>(),
out.data<uint32_t>(),
out.size(),