mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 15:59:27 +08:00
[CUDA] Remove thrust in arange (#2535)
This commit is contained in:
parent
f55b6f1f2f
commit
333ffea273
@ -6,23 +6,33 @@
|
|||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
#include <thrust/device_ptr.h>
|
|
||||||
#include <thrust/transform.h>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
template <typename T>
|
namespace cg = cooperative_groups;
|
||||||
struct Arange {
|
|
||||||
const T start;
|
|
||||||
const T step;
|
|
||||||
|
|
||||||
__device__ T operator()(uint32_t i) const {
|
template <typename T, typename IdxT, int N_WRITES>
|
||||||
return start + i * step;
|
__global__ void arange(T* out, IdxT size, T start, T step) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
|
||||||
|
if ((index + 1) * N_WRITES > size) {
|
||||||
|
for (IdxT i = index * N_WRITES; i < size; ++i) {
|
||||||
|
out[i] = start + i * step;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
AlignedVector<T, N_WRITES> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_WRITES; ++i) {
|
||||||
|
out_vec[i] = start + (index * N_WRITES + i) * step;
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_WRITES>(out, index, out_vec);
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
|
|
||||||
@ -36,19 +46,23 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& encoder = cu::get_command_encoder(stream());
|
auto& encoder = cu::get_command_encoder(stream());
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
auto capture = encoder.capture_context();
|
|
||||||
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||||
using OutType = cuda_type_t<CTYPE>;
|
using OutType = cuda_type_t<CTYPE>;
|
||||||
CTYPE step =
|
constexpr int N_WRITES = 16 / sizeof(OutType);
|
||||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||||
thrust::transform(
|
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||||
cu::thrust_policy(encoder.stream()),
|
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
|
||||||
thrust::counting_iterator<uint32_t>(0),
|
encoder.add_kernel_node(
|
||||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
cu::arange<OutType, IdxT, N_WRITES>,
|
||||||
thrust::device_pointer_cast(out.data<OutType>()),
|
num_blocks,
|
||||||
cu::Arange<OutType>{
|
block_dims,
|
||||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
0,
|
||||||
|
out.data<OutType>(),
|
||||||
|
out.data_size(),
|
||||||
|
static_cast<CTYPE>(start_),
|
||||||
|
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
|
||||||
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <thrust/iterator/transform_iterator.h>
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
@ -116,15 +115,4 @@ inline __host__ __device__ auto cast_to(SrcT x) {
|
|||||||
return CastOp<SrcT, DstT>{}(x);
|
return CastOp<SrcT, DstT>{}(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return an iterator that cast the value to DstT using CastOp.
|
|
||||||
template <typename DstT, typename Iterator>
|
|
||||||
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
|
|
||||||
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
|
|
||||||
if constexpr (std::is_same_v<SrcT, DstT>) {
|
|
||||||
return it;
|
|
||||||
} else {
|
|
||||||
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
Loading…
Reference in New Issue
Block a user