mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Rename the dispatch_block_dim lambda argument from block_dim_constant to block_dim
This commit is contained in:
@@ -156,22 +156,15 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim_constant) {
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims{block_dim_constant(), 1, 1};
|
||||
auto kernel = cu::arg_reduce_general<
|
||||
T,
|
||||
cu::ArgMax<T>,
|
||||
block_dim_constant(),
|
||||
N_READS>;
|
||||
auto kernel =
|
||||
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = cu::arg_reduce_general<
|
||||
T,
|
||||
cu::ArgMin<T>,
|
||||
block_dim_constant(),
|
||||
N_READS>;
|
||||
kernel = cu::
|
||||
arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
kernel<<<num_blocks, block_dim(), 0, stream>>>(
|
||||
in.data<T>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
|
||||
Reference in New Issue
Block a user