Files
mlx/mlx/backend/cuda/copy/copy_contiguous.cu
Awni Hannun df58b4133a
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
[CUDA] Reduce use of managed memory (#2725)
* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
2025-11-05 16:05:23 -08:00

89 lines
2.4 KiB
Plaintext

// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_s(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = cast_to<Out>(in[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = cast_to<Out>(in[0]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = cast_to<Out>(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = cast_to<Out>(in_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
} // namespace cu
void copy_contiguous(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t in_offset,
int64_t out_offset) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
gpu_ptr<InType>(in) + in_offset,
gpu_ptr<OutType>(out) + out_offset,
out.data_size());
});
});
});
}
} // namespace mlx::core