mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Test the native cuda graph api
This commit is contained in:
@@ -22,6 +22,9 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Not all engines support it so can not use this API now.
|
||||||
|
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||||
|
|
||||||
struct ConvCacheKey {
|
struct ConvCacheKey {
|
||||||
int device_id;
|
int device_id;
|
||||||
cudnnBackendDescriptorType_t backend_type;
|
cudnnBackendDescriptorType_t backend_type;
|
||||||
@@ -181,6 +184,20 @@ bool execute_plan(
|
|||||||
.setUids(3, uids)
|
.setUids(3, uids)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||||
|
cudaGraph_t graph;
|
||||||
|
cudaGraphCreate(&graph, 0);
|
||||||
|
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||||
|
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||||
|
if (cudnnBackendPopulateCudaGraph(
|
||||||
|
encoder.device().cudnn_handle(),
|
||||||
|
plan.get_raw_desc(),
|
||||||
|
variantPack.get_raw_desc(),
|
||||||
|
graph) != CUDNN_STATUS_SUCCESS) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
encoder.add_graph_node(graph);
|
||||||
|
#else
|
||||||
auto capture = encoder.capture_context();
|
auto capture = encoder.capture_context();
|
||||||
if (cudnnBackendExecute(
|
if (cudnnBackendExecute(
|
||||||
encoder.device().cudnn_handle(),
|
encoder.device().cudnn_handle(),
|
||||||
@@ -190,6 +207,7 @@ bool execute_plan(
|
|||||||
capture.discard = true;
|
capture.discard = true;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
encoder.add_temporary(workspace);
|
encoder.add_temporary(workspace);
|
||||||
return true;
|
return true;
|
||||||
|
|||||||
Reference in New Issue
Block a user