[CUDA] Switch to CUDA graphs (#2317)

* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

cosistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
This commit is contained in:
Awni Hannun
2025-07-02 15:59:13 -07:00
committed by GitHub
parent e76e9b87f0
commit ec0d5db67b
36 changed files with 1461 additions and 1212 deletions

View File

@@ -50,32 +50,6 @@ array swapaxes_in_eval(const array& in, int axis1, int axis2) {
return out;
}
template <typename... Args>
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(
cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
temp.data<void>(), size, args...));
}
template <typename... Args>
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(
cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
temp.data<void>(), size, args...));
}
struct OffsetTransform {
int nsort;
@@ -113,57 +87,94 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using Type = cuda_type_t<CTYPE>;
auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0), OffsetTransform{nsort});
if (argsort) {
// Indices in the sorted dimension.
array indices(
allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(indices);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
auto& stream = encoder.stream();
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using Type = cuda_type_t<CTYPE>;
auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0), OffsetTransform{nsort});
if (argsort) {
// Indices in the sorted dimension.
array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(indices);
// In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it.
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
encoder.add_temporary(discard);
// In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it.
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
encoder.add_temporary(discard);
segmented_sort_pairs(
encoder,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream);
} else {
segmented_sort(
encoder,
in.data<Type>(),
out.data<Type>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream);
}
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
nullptr,
size,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
auto capture = encoder.capture_context();
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
temp.data<void>(),
size,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
} else {
throw std::runtime_error(
"CUDA backend does not support sorting complex numbers");
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
nullptr,
size,
in.data<Type>(),
out.data<Type>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
temp.data<void>(),
size,
in.data<Type>(),
out.data<Type>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
}
});
} else {
throw std::runtime_error(
"CUDA backend does not support sorting complex numbers");
}
});
if (!is_segmented_sort) {