mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
MLX_SWITCH macros to templates (#2320)
This commit is contained in:
committed by
GitHub
parent
33bf1a244b
commit
3d5e17e507
@@ -76,6 +76,14 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct OffsetTransform {
|
||||
int nsort;
|
||||
|
||||
int __device__ operator()(int i) {
|
||||
return i * nsort;
|
||||
}
|
||||
};
|
||||
|
||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
array out = out_;
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
@@ -106,12 +114,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
using Type = cuda_type_t<CTYPE>;
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0),
|
||||
[nsort] __device__(int i) { return i * nsort; });
|
||||
thrust::make_counting_iterator(0), OffsetTransform{nsort});
|
||||
if (argsort) {
|
||||
// Indices in the sorted dimension.
|
||||
array indices(
|
||||
|
||||
Reference in New Issue
Block a user