Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Add RAII managed CudaGraph class
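The diff below replaces the raw cudaGraph_t handles, and the ad-hoc std::unique_ptr deleters that guarded them, with a CudaGraph class built on a shared CudaHandle<Handle, Destroy> RAII base; CudaStream and CudaGraphExec are reworked on the same base. As a rough illustration of what this buys at a call site, here is a minimal sketch (not part of the commit): populate_graph is a hypothetical placeholder for work such as cudnnBackendPopulateCudaGraph in the first hunk, and error checking is elided.

    // Sketch only, not part of this diff. populate_graph() is hypothetical.
    bool populate_graph(cudaGraph_t graph);

    void build_and_instantiate(mlx::core::cu::Device& device) {
      mlx::core::CudaGraph graph(device);   // cudaGraphCreate in the constructor
      if (!populate_graph(graph)) {         // implicit conversion to cudaGraph_t
        return;                             // cudaGraphDestroy runs automatically
      }
      mlx::core::CudaGraphExec exec;
      exec.instantiate(graph);              // owns the resulting cudaGraphExec_t
    }                                       // both handles are released here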
@@ -190,10 +190,7 @@ bool execute_plan(
   cudnnSetStream(handle, encoder.stream());

 #if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
-  cudaGraph_t graph;
-  cudaGraphCreate(&graph, 0);
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
+  CudaGraph graph(encoder.device());
   if (cudnnBackendPopulateCudaGraph(
           handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
       CUDNN_STATUS_SUCCESS) {
@@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 }

 CommandEncoder::CaptureContext::~CaptureContext() {
-  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  graph.end_capture(enc.stream());
   if (discard) {
     return;
   }
@@ -185,9 +183,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 }

 CommandEncoder::CommandEncoder(Device& d)
-    : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
-}
+    : device_(d),
+      stream_(d),
+      graph_(d),
+      graph_cache_(cuda_graph_cache_size()) {}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -311,8 +310,7 @@ void CommandEncoder::commit() {
     to_nodes_.clear();
     graph_key_.clear();
     node_map_.clear();
-    CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
-    CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+    graph_ = CudaGraph(device_);
   }

   // Put completion handlers in a batch.
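An aside on the hunk above: the destroy-then-recreate pair in commit() collapses into a single move assignment. With the CudaHandle move assignment operator added in the last hunk of this diff, the assignment first reset()s the old graph, destroying it, and then adopts the freshly created handle; the drained temporary is destroyed as a no-op. Spelled out as a sketch:

    // Sketch only: the effect of "graph_ = CudaGraph(device_);" assuming the
    // CudaHandle move assignment defined later in this diff.
    CudaGraph fresh(device_);   // cudaGraphCreate on a new handle
    graph_ = std::move(fresh);  // reset() destroys the old graph_, then the
                                // handles are swapped; fresh is left empty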
@@ -21,7 +21,7 @@ class CommandEncoder {
   struct CaptureContext {
     CaptureContext(CommandEncoder& enc);
     ~CaptureContext();
-    cudaGraph_t graph;
+    CudaGraph graph;
     CommandEncoder& enc;
     bool discard{false};
   };
@@ -115,7 +115,7 @@ class CommandEncoder {

   Device& device_;
   CudaStream stream_;
-  cudaGraph_t graph_;
+  CudaGraph graph_;
   Worker worker_;
   char node_count_{0};
   char graph_node_count_{0};
@@ -8,36 +8,6 @@

 namespace mlx::core {

-CudaStream::CudaStream(cu::Device& device) {
-  device.make_current();
-  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
-}
-
-CudaStream::~CudaStream() {
-  CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
-}
-
-CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
-
-CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
-  other.handle_ = nullptr;
-};
-
-CudaGraphExec::~CudaGraphExec() {
-  reset();
-}
-
-void CudaGraphExec::instantiate(cudaGraph_t graph) {
-  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
-}
-
-void CudaGraphExec::reset() {
-  if (handle_ != nullptr) {
-    CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
-    handle_ = nullptr;
-  }
-}
-
 void check_cublas_error(const char* name, cublasStatus_t err) {
   if (err != CUBLAS_STATUS_SUCCESS) {
     // TODO: Use cublasGetStatusString when it is widely available.
@@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
   }
 }

+CudaGraph::CudaGraph(cu::Device& device) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));
+}
+
+void CudaGraph::end_capture(cudaStream_t stream) {
+  assert(handle_ == nullptr);
+  CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
+}
+
+void CudaGraphExec::instantiate(cudaGraph_t graph) {
+  assert(handle_ == nullptr);
+  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
+}
+
+CudaStream::CudaStream(cu::Device& device) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));
+}
+
 } // namespace mlx::core
@@ -16,44 +16,6 @@ class Device;

 struct Dtype;

-// Cuda stream managed with RAII.
-class CudaStream {
- public:
-  explicit CudaStream(cu::Device& device);
-  ~CudaStream();
-
-  CudaStream(const CudaStream&) = delete;
-  CudaStream& operator=(const CudaStream&) = delete;
-
-  operator cudaStream_t() const {
-    return stream_;
-  }
-
- private:
-  cudaStream_t stream_;
-};
-
-// Move-able RAII handle of cudaGraphExec_t.
-class CudaGraphExec {
- public:
-  CudaGraphExec(cudaGraphExec_t handle = nullptr);
-  CudaGraphExec(CudaGraphExec&& other);
-  ~CudaGraphExec();
-
-  CudaGraphExec(const CudaGraphExec&) = delete;
-  CudaGraphExec& operator=(const CudaGraphExec&) = delete;
-
-  void instantiate(cudaGraph_t graph);
-  void reset();
-
-  operator cudaGraphExec_t() const {
-    return handle_;
-  }
-
- private:
-  cudaGraphExec_t handle_;
-};
-
 // Throw exception if the cuda API does not succeed.
 void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
@@ -66,4 +28,62 @@ void check_cuda_error(const char* name, CUresult err);
 // Convert Dtype to CUDA C++ types.
 const char* dtype_to_cuda_type(const Dtype& dtype);

+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  explicit CudaStream(cu::Device& device);
+};
+
 } // namespace mlx::core
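Since CudaHandle is parameterized over the handle type and its destroy function, other CUDA resources can be wrapped the same way. Purely as an illustration (not something this diff adds), a hypothetical wrapper for cudaEvent_t might look like:

    // Illustration only, not part of this diff. cudaEventDestroy has the
    // cudaError_t(cudaEvent_t) signature expected by CudaHandle.
    class CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
     public:
      CudaEventHandle() {
        CHECK_CUDA_ERROR(
            cudaEventCreateWithFlags(&handle_, cudaEventDisableTiming));
      }
    };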