Add RAII managed CudaGraph class

commit 290b45eba3 (parent c422050ca7)
Author: Cheng
Date: 2025-08-11 18:20:00 -07:00

5 changed files with 87 additions and 82 deletions
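In summary, the ad-hoc cudaGraph_t handles and per-call-site std::unique_ptr custom deleters are replaced by a CudaHandle<Handle, Destroy> base class plus thin CudaGraph, CudaGraphExec, and CudaStream wrappers, so creation, move, and destruction of the underlying CUDA objects live in one place. The following is a minimal standalone sketch of that ownership pattern, not MLX code: the names Owned, Graph, and CHECK_CUDA are illustrative, only the CUDA runtime is assumed, and it needs a CUDA-capable device to run.

// Standalone sketch of the RAII ownership pattern adopted below.
#include <cstdio>
#include <cstdlib>
#include <utility>

#include <cuda_runtime.h>

// Local stand-in for MLX's CHECK_CUDA_ERROR macro.
#define CHECK_CUDA(expr)                                               \
  do {                                                                 \
    cudaError_t e_ = (expr);                                           \
    if (e_ != cudaSuccess) {                                           \
      std::fprintf(stderr, "%s: %s\n", #expr, cudaGetErrorString(e_)); \
      std::abort();                                                    \
    }                                                                  \
  } while (0)

// Same shape as the CudaHandle template this commit adds: owns one handle,
// destroys it exactly once, movable but not copyable.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class Owned {
 public:
  Owned(Handle h = nullptr) : handle_(h) {}
  Owned(Owned&& other) : handle_(other.handle_) { other.handle_ = nullptr; }
  Owned& operator=(Owned&& other) {
    reset();  // Destroy what we held before.
    std::swap(handle_, other.handle_);
    return *this;
  }
  Owned(const Owned&) = delete;
  Owned& operator=(const Owned&) = delete;
  ~Owned() { reset(); }
  void reset() {
    if (handle_ != nullptr) {
      CHECK_CUDA(Destroy(handle_));
      handle_ = nullptr;
    }
  }
  operator Handle() const { return handle_; }  // Pass straight to C APIs.

 private:
  Handle handle_;
};

using Graph = Owned<cudaGraph_t, cudaGraphDestroy>;

int main() {
  cudaGraph_t raw = nullptr;
  CHECK_CUDA(cudaGraphCreate(&raw, 0));
  Graph g(raw);             // g owns the graph; no unique_ptr deleter needed.
  Graph g2 = std::move(g);  // Ownership moves; g is left empty.
  g2 = Graph();             // Old graph destroyed here, empty one held: the
                            // same idiom as graph_ = CudaGraph(device_) in
                            // CommandEncoder::commit() below.
  std::puts("ok");
  return 0;
}  // Any remaining handle is released automatically.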

View File

@@ -190,10 +190,7 @@ bool execute_plan(
   cudnnSetStream(handle, encoder.stream());
 #if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
-  cudaGraph_t graph;
-  cudaGraphCreate(&graph, 0);
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
+  CudaGraph graph(encoder.device());
   if (cudnnBackendPopulateCudaGraph(
           handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
       CUDNN_STATUS_SUCCESS) {

View File

@@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
-  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  graph.end_capture(enc.stream());
   if (discard) {
     return;
   }
@@ -185,9 +183,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 }
 
 CommandEncoder::CommandEncoder(Device& d)
-    : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
-}
+    : device_(d),
+      stream_(d),
+      graph_(d),
+      graph_cache_(cuda_graph_cache_size()) {}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -311,8 +310,7 @@ void CommandEncoder::commit() {
     to_nodes_.clear();
     graph_key_.clear();
     node_map_.clear();
-    CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
-    CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+    graph_ = CudaGraph(device_);
   }
 
   // Put completion handlers in a batch.

View File

@@ -21,7 +21,7 @@ class CommandEncoder {
   struct CaptureContext {
     CaptureContext(CommandEncoder& enc);
     ~CaptureContext();
-    cudaGraph_t graph;
+    CudaGraph graph;
     CommandEncoder& enc;
     bool discard{false};
   };
@@ -115,7 +115,7 @@ class CommandEncoder {
   Device& device_;
   CudaStream stream_;
-  cudaGraph_t graph_;
+  CudaGraph graph_;
   Worker worker_;
   char node_count_{0};
   char graph_node_count_{0};

View File

@@ -8,36 +8,6 @@
 namespace mlx::core {
 
-CudaStream::CudaStream(cu::Device& device) {
-  device.make_current();
-  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
-}
-
-CudaStream::~CudaStream() {
-  CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
-}
-
-CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
-
-CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
-  other.handle_ = nullptr;
-};
-
-CudaGraphExec::~CudaGraphExec() {
-  reset();
-}
-
-void CudaGraphExec::instantiate(cudaGraph_t graph) {
-  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
-}
-
-void CudaGraphExec::reset() {
-  if (handle_ != nullptr) {
-    CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
-    handle_ = nullptr;
-  }
-}
-
 void check_cublas_error(const char* name, cublasStatus_t err) {
   if (err != CUBLAS_STATUS_SUCCESS) {
     // TODO: Use cublasGetStatusString when it is widely available.
@@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
   }
 }
 
+CudaGraph::CudaGraph(cu::Device& device) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));
+}
+
+void CudaGraph::end_capture(cudaStream_t stream) {
+  assert(handle_ == nullptr);
+  CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
+}
+
+void CudaGraphExec::instantiate(cudaGraph_t graph) {
+  assert(handle_ == nullptr);
+  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
+}
+
+CudaStream::CudaStream(cu::Device& device) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));
+}
+
 } // namespace mlx::core
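The assert(handle_ == nullptr) in end_capture reflects that cudaStreamEndCapture itself produces the graph, so the wrapper has to start out empty (the CaptureContext member is default-constructed for exactly this reason). Below is a simplified, self-contained sketch of the capture, instantiate, and launch sequence these helpers wrap; it uses raw CUDA runtime calls and a local error macro rather than MLX types, and the buffer size is arbitrary.

// Sketch of the stream-capture workflow behind CaptureContext/CudaGraphExec.
#include <cstdio>
#include <cstdlib>

#include <cuda_runtime.h>

#define CHECK_CUDA(expr)                                               \
  do {                                                                 \
    cudaError_t e_ = (expr);                                           \
    if (e_ != cudaSuccess) {                                           \
      std::fprintf(stderr, "%s: %s\n", #expr, cudaGetErrorString(e_)); \
      std::exit(1);                                                    \
    }                                                                  \
  } while (0)

int main() {
  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

  void* buf = nullptr;
  CHECK_CUDA(cudaMalloc(&buf, 1 << 20));

  // Record asynchronous work into a graph instead of executing it.
  CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
  CHECK_CUDA(cudaMemsetAsync(buf, 0, 1 << 20, stream));
  cudaGraph_t graph = nullptr;  // Filled in by end-capture.
  CHECK_CUDA(cudaStreamEndCapture(stream, &graph));

  // Instantiate once; the executable graph can then be launched repeatedly.
  cudaGraphExec_t exec = nullptr;
  CHECK_CUDA(cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0));
  CHECK_CUDA(cudaGraphLaunch(exec, stream));
  CHECK_CUDA(cudaStreamSynchronize(stream));

  // With the RAII wrappers these destroys happen automatically.
  CHECK_CUDA(cudaGraphExecDestroy(exec));
  CHECK_CUDA(cudaGraphDestroy(graph));
  CHECK_CUDA(cudaFree(buf));
  CHECK_CUDA(cudaStreamDestroy(stream));
  std::puts("captured, instantiated and launched a CUDA graph");
  return 0;
}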

View File

@@ -16,44 +16,6 @@ class Device;
 struct Dtype;
 
-// Cuda stream managed with RAII.
-class CudaStream {
- public:
-  explicit CudaStream(cu::Device& device);
-  ~CudaStream();
-
-  CudaStream(const CudaStream&) = delete;
-  CudaStream& operator=(const CudaStream&) = delete;
-
-  operator cudaStream_t() const {
-    return stream_;
-  }
-
- private:
-  cudaStream_t stream_;
-};
-
-// Move-able RAII handle of cudaGraphExec_t.
-class CudaGraphExec {
- public:
-  CudaGraphExec(cudaGraphExec_t handle = nullptr);
-  CudaGraphExec(CudaGraphExec&& other);
-  ~CudaGraphExec();
-
-  CudaGraphExec(const CudaGraphExec&) = delete;
-  CudaGraphExec& operator=(const CudaGraphExec&) = delete;
-
-  void instantiate(cudaGraph_t graph);
-  void reset();
-
-  operator cudaGraphExec_t() const {
-    return handle_;
-  }
-
- private:
-  cudaGraphExec_t handle_;
-};
-
 // Throw exception if the cuda API does not succeed.
 void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
@@ -66,4 +28,62 @@ void check_cuda_error(const char* name, CUresult err);
 // Convert Dtype to CUDA C++ types.
 const char* dtype_to_cuda_type(const Dtype& dtype);
 
+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  explicit CudaStream(cu::Device& device);
+};
+
 } // namespace mlx::core
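Because the destroy function is a non-type template parameter, adding an RAII wrapper for another CUDA object is only a few lines. As a hedged illustration, a CudaEvent wrapper like the one below is not part of this commit (and MLX may manage events differently elsewhere); it only assumes the CudaHandle template and CHECK_CUDA_ERROR macro from the utils.h hunk above.

// Hypothetical extension built on the CudaHandle base above; not in this commit.
class CudaEvent : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 public:
  explicit CudaEvent(unsigned int flags = cudaEventDisableTiming) {
    CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
  }

  void record(cudaStream_t stream) {
    CHECK_CUDA_ERROR(cudaEventRecord(handle_, stream));
  }
};

// Usage sketch: the event is destroyed when `done` goes out of scope.
//   CudaStream stream(device);
//   CudaEvent done;
//   done.record(stream);

The same pattern applies to any CUDA object whose destroy function takes the handle and returns cudaError_t.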