[CUDA] Switch to CUDA graphs (#2317)

* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

consistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
This commit is contained in:
Awni Hannun
2025-07-02 15:59:13 -07:00
committed by GitHub
parent e76e9b87f0
commit ec0d5db67b
36 changed files with 1461 additions and 1212 deletions

View File

@@ -26,16 +26,6 @@ void check_nvrtc_error(const char* name, nvrtcResult err) {
}
}
#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
// Throw std::runtime_error when a CUDA driver API call did not succeed.
//
// name: text identifying the call (CHECK_CU_ERROR passes the stringified
//       expression); used as the prefix of the exception message.
// err:  result code returned by the driver API call.
void check_cu_error(const char* name, CUresult err) {
  if (err == CUDA_SUCCESS) {
    return;
  }
  const char* err_str = nullptr;
  // cuGetErrorString fails and stores NULL in the output pointer when the
  // code is not recognized, so fall back to a fixed message instead of
  // handing a null C string to fmt::format.
  if (cuGetErrorString(err, &err_str) != CUDA_SUCCESS || err_str == nullptr) {
    err_str = "Unknown error";
  }
  throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
}
// Return the location of the CUDA toolkit.
const std::string& cuda_home() {
static std::string home = []() -> std::string {
@@ -280,60 +270,13 @@ JitModule::JitModule(
// Load kernels.
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel;
}
}
// Unload the CUDA module held in module_ when the JitModule is destroyed.
// NOTE(review): CHECK_CU_ERROR throws std::runtime_error on failure; an
// exception escaping a destructor terminates the program unless the
// destructor is declared noexcept(false) — confirm this is the intended
// failure mode.
JitModule::~JitModule() {
CHECK_CU_ERROR(cuModuleUnload(module_));
}
// Launch the kernel registered under `kernel_name` over the elements of `arr`.
//
// stream:          CUDA stream the launch is issued on.
// kernel_name:     key looked up through get_kernel().
// arr:             array whose size (and, when `large`, shape/strides)
//                  determines the launch geometry.
// large:           when true the grid comes from get_2d_grid_dims_common and
//                  its first dimension is divided by the block size (rounding
//                  up); otherwise a 1D grid is used.
// work_per_thread: divisor passed to cuda::ceil_div to turn the element count
//                  into a thread count.
//
// The block size is the one suggested by cuOccupancyMaxPotentialBlockSize,
// capped at the thread count.
void JitModule::launch_kernel(
    CUstream stream,
    const std::string& kernel_name,
    const array& arr,
    bool large,
    int work_per_thread) {
  CUfunction kernel = get_kernel(kernel_name);
  size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
  if (nthreads == 0) {
    // Nothing to launch. Without this guard block_dim would be clamped to 0
    // and the grid computation below would divide by zero. Drop the pending
    // arguments exactly as a completed launch would.
    args_.clear();
    storage_.clear();
    return;
  }
  int _, block_dim;
  CHECK_CU_ERROR(
      cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
  // Never use more threads per block than there is work for.
  if (static_cast<size_t>(block_dim) > nthreads) {
    block_dim = static_cast<int>(nthreads);
  }
  Dims num_blocks{1, 1, 1};
  if (large) {
    num_blocks =
        get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
    std::get<0>(num_blocks) =
        (std::get<0>(num_blocks) + block_dim - 1) / block_dim;
  } else {
    std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
  }
  launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
}
// Launch `kernel` on `stream` with an explicit grid (`num_blocks`) and block
// (`block_dims`) size, passing the pointers collected in args_ as the kernel
// parameters. Throws via CHECK_CU_ERROR if cuLaunchKernel fails.
void JitModule::launch_kernel(
CUstream stream,
CUfunction kernel,
Dims num_blocks,
Dims block_dims) {
// The literal 0 is cuLaunchKernel's dynamic shared memory size and the
// trailing nullptr is its `extra` options argument.
CHECK_CU_ERROR(cuLaunchKernel(
kernel,
std::get<0>(num_blocks),
std::get<1>(num_blocks),
std::get<2>(num_blocks),
std::get<0>(block_dims),
std::get<1>(block_dims),
std::get<2>(block_dims),
0,
stream,
args_.data(),
nullptr));
// Empty args_ and storage_ after the launch has been issued.
// NOTE(review): storage_ presumably backs by-value arguments — confirm.
args_.clear();
storage_.clear();
// NOTE(review): this unload looks like the replacement body of
// ~JitModule() from the rendered diff rather than part of launch_kernel;
// unloading module_ after every launch is almost certainly not intended —
// confirm against the real file.
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
@@ -345,10 +288,6 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
return it->second;
}
// Queue the address `v` as the next kernel argument by appending it to args_.
// Constness is cast away because args_ stores non-const pointers.
void JitModule::append_ptr_arg(const void* v) {
  void* arg_ptr = const_cast<void*>(v);
  args_.push_back(arg_ptr);
}
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,