mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Test the native cuda graph api
This commit is contained in:
@@ -22,6 +22,9 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Not all engines support it, so this API cannot be used for now.
|
||||
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||
|
||||
struct ConvCacheKey {
|
||||
int device_id;
|
||||
cudnnBackendDescriptorType_t backend_type;
|
||||
@@ -181,6 +184,20 @@ bool execute_plan(
|
||||
.setUids(3, uids)
|
||||
.build();
|
||||
|
||||
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||
cudaGraph_t graph;
|
||||
cudaGraphCreate(&graph, 0);
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||
if (cudnnBackendPopulateCudaGraph(
|
||||
encoder.device().cudnn_handle(),
|
||||
plan.get_raw_desc(),
|
||||
variantPack.get_raw_desc(),
|
||||
graph) != CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
#else
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(
|
||||
encoder.device().cudnn_handle(),
|
||||
@@ -190,6 +207,7 @@ bool execute_plan(
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
|
||||
Reference in New Issue
Block a user