mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Change name of argument
@@ -156,22 +156,15 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 constexpr uint32_t N_READS = 4;
 dispatch_block_dim(
-    cuda::ceil_div(axis_size, N_READS), [&](auto block_dim_constant) {
+    cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
       dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-      dim3 block_dims{block_dim_constant(), 1, 1};
-      auto kernel = cu::arg_reduce_general<
-          T,
-          cu::ArgMax<T>,
-          block_dim_constant(),
-          N_READS>;
+      auto kernel =
+          cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
       if (reduce_type_ == ArgReduce::ArgMin) {
-        kernel = cu::arg_reduce_general<
-            T,
-            cu::ArgMin<T>,
-            block_dim_constant(),
-            N_READS>;
+        kernel = cu::
+            arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
       }
-      kernel<<<num_blocks, block_dims, 0, stream>>>(
+      kernel<<<num_blocks, block_dim(), 0, stream>>>(
           in.data<T>(),
           out.data<uint32_t>(),
           out.size(),
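Note on the pattern this rename touches: the lambda argument (now block_dim) can appear both in the launch configuration and inside the kernel's template argument list (block_dim()) only because dispatch_block_dim hands the lambda a callable that yields a compile-time constant. The snippet below is a minimal, hypothetical sketch of that dispatch idiom in plain host-side C++; dispatch_block_dim_sketch, fake_kernel, and the particular block-size thresholds are illustrative assumptions, not MLX's actual implementation.

#include <cstdint>
#include <cstdio>
#include <type_traits>

// Hypothetical stand-in for a dispatch_block_dim-style helper: map a runtime
// thread count to one of a few supported block sizes and pass it on as a
// std::integral_constant, so the receiving lambda can use block_dim() in
// constant expressions.
template <typename F>
void dispatch_block_dim_sketch(uint32_t threads, F&& f) {
  if (threads <= 128) {
    f(std::integral_constant<uint32_t, 128>{});
  } else if (threads <= 512) {
    f(std::integral_constant<uint32_t, 512>{});
  } else {
    f(std::integral_constant<uint32_t, 1024>{});
  }
}

// Toy placeholder for a kernel template that needs the block size at compile
// time, as cu::arg_reduce_general does in the diff above.
template <uint32_t BLOCK_DIM>
void fake_kernel() {
  std::printf("instantiated with BLOCK_DIM = %u\n",
              static_cast<unsigned>(BLOCK_DIM));
}

int main() {
  uint32_t axis_size = 5000;
  constexpr uint32_t N_READS = 4;
  // Same call shape as the diff: the lambda parameter is named block_dim, and
  // block_dim() works both as a template argument and as a plain value.
  dispatch_block_dim_sketch(
      (axis_size + N_READS - 1) / N_READS, [&](auto block_dim) {
        fake_kernel<block_dim()>();
        std::printf("launch width = %u\n",
                    static_cast<unsigned>(block_dim()));
      });
  return 0;
}

Passing the block size as a constant-returning callable is what lets one lambda body instantiate the kernel template for every supported block size; the commit only shortens the parameter name from block_dim_constant to block_dim (and drops the separate dim3 block_dims temporary), without changing that mechanism.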