Remove thrust iterators (#2396)

This commit is contained in:
Cheng
2025-07-21 23:30:27 +09:00
committed by GitHub
parent 93d70419e7
commit f55c4ed1d6
7 changed files with 23 additions and 193 deletions

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -115,7 +115,7 @@ __global__ void arg_reduce_general(
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
}