Add RAII managed CudaGraph class

This commit is contained in:
Cheng
2025-08-11 18:20:00 -07:00
parent fce53b61d6
commit 3ebc047168
5 changed files with 87 additions and 82 deletions

View File

@@ -193,10 +193,7 @@ bool execute_plan(
cudnnSetStream(handle, encoder.stream());
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
CudaGraph graph(encoder.device());
if (cudnnBackendPopulateCudaGraph(
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
CUDNN_STATUS_SUCCESS) {

View File

@@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
}
CommandEncoder::CaptureContext::~CaptureContext() {
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
graph.end_capture(enc.stream());
if (discard) {
return;
}
@@ -185,9 +183,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
}
CommandEncoder::CommandEncoder(Device& d)
: device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
}
: device_(d),
stream_(d),
graph_(d),
graph_cache_(cuda_graph_cache_size()) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
@@ -311,8 +310,7 @@ void CommandEncoder::commit() {
to_nodes_.clear();
graph_key_.clear();
node_map_.clear();
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
graph_ = CudaGraph(device_);
}
// Put completion handlers in a batch.

View File

@@ -21,7 +21,7 @@ class CommandEncoder {
struct CaptureContext {
CaptureContext(CommandEncoder& enc);
~CaptureContext();
cudaGraph_t graph;
CudaGraph graph;
CommandEncoder& enc;
bool discard{false};
};
@@ -115,7 +115,7 @@ class CommandEncoder {
Device& device_;
CudaStream stream_;
cudaGraph_t graph_;
CudaGraph graph_;
Worker worker_;
char node_count_{0};
char graph_node_count_{0};

View File

@@ -8,36 +8,6 @@
namespace mlx::core {
CudaStream::CudaStream(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
}
CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
other.handle_ = nullptr;
};
CudaGraphExec::~CudaGraphExec() {
reset();
}
void CudaGraphExec::instantiate(cudaGraph_t graph) {
CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
}
void CudaGraphExec::reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
handle_ = nullptr;
}
}
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
@@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
}
}
CudaGraph::CudaGraph(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));
}
void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
}
void CudaGraphExec::instantiate(cudaGraph_t graph) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
}
CudaStream::CudaStream(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));
}
} // namespace mlx::core

View File

@@ -16,44 +16,6 @@ class Device;
struct Dtype;
// Cuda stream managed with RAII.
class CudaStream {
public:
explicit CudaStream(cu::Device& device);
~CudaStream();
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const {
return stream_;
}
private:
cudaStream_t stream_;
};
// Move-able RAII handle of cudaGraphExec_t.
class CudaGraphExec {
public:
CudaGraphExec(cudaGraphExec_t handle = nullptr);
CudaGraphExec(CudaGraphExec&& other);
~CudaGraphExec();
CudaGraphExec(const CudaGraphExec&) = delete;
CudaGraphExec& operator=(const CudaGraphExec&) = delete;
void instantiate(cudaGraph_t graph);
void reset();
operator cudaGraphExec_t() const {
return handle_;
}
private:
cudaGraphExec_t handle_;
};
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
@@ -66,4 +28,62 @@ void check_cuda_error(const char* name, CUresult err);
// Convert Dtype to CUDA C++ types.
const char* dtype_to_cuda_type(const Dtype& dtype);
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
public:
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
assert(this != &other);
other.handle_ = nullptr;
}
~CudaHandle() {
reset();
}
CudaHandle(const CudaHandle&) = delete;
CudaHandle& operator=(const CudaHandle&) = delete;
CudaHandle& operator=(CudaHandle&& other) {
assert(this != &other);
reset();
std::swap(handle_, other.handle_);
return *this;
}
void reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(Destroy(handle_));
handle_ = nullptr;
}
}
operator Handle() const {
return handle_;
}
protected:
Handle handle_;
};
// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
public:
using CudaHandle::CudaHandle;
explicit CudaGraph(cu::Device& device);
void end_capture(cudaStream_t stream);
};
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
public:
void instantiate(cudaGraph_t graph);
};
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
public:
explicit CudaStream(cu::Device& device);
};
} // namespace mlx::core