Set cudnn stream before execution

Cheng
2025-07-20 05:24:48 -07:00
parent 67a5f7b2a8
commit 3d16cb5071
3 changed files with 7 additions and 11 deletions
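
The change moves the cudnnSetStream call out of the CommandEncoder constructor and into execute_plan, so the shared cuDNN handle is rebound to the encoder's current stream immediately before each plan launch. A minimal sketch of that pattern, assuming a finalized cuDNN backend plan and variant pack (the helper name and parameters below are illustrative, not code from this repo):

#include <cuda_runtime.h>
#include <cudnn.h>

// Sketch: bind the (possibly shared) cuDNN handle to the stream that will run
// the work right before executing the plan, rather than once at handle setup.
bool execute_on_stream(
    cudnnHandle_t handle,
    cudaStream_t stream,
    cudnnBackendDescriptor_t plan,
    cudnnBackendDescriptor_t variant_pack) {
  // cudnnBackendExecute enqueues onto whatever stream the handle currently holds,
  // so set it per execution.
  if (cudnnSetStream(handle, stream) != CUDNN_STATUS_SUCCESS) {
    return false;
  }
  return cudnnBackendExecute(handle, plan, variant_pack) == CUDNN_STATUS_SUCCESS;
}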

View File

@@ -184,25 +184,25 @@ bool execute_plan(
           .setUids(3, uids)
           .build();
 
+  auto handle = encoder.device().cudnn_handle();
+  cudnnSetStream(handle, encoder.stream());
+
 #if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
   cudaGraph_t graph;
   cudaGraphCreate(&graph, 0);
   std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
       &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
   if (cudnnBackendPopulateCudaGraph(
-          encoder.device().cudnn_handle(),
-          plan.get_raw_desc(),
-          variantPack.get_raw_desc(),
-          graph) != CUDNN_STATUS_SUCCESS) {
+          handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
+      CUDNN_STATUS_SUCCESS) {
     return false;
   }
   encoder.add_graph_node(graph);
 #else
   auto capture = encoder.capture_context();
   if (cudnnBackendExecute(
-          encoder.device().cudnn_handle(),
-          plan.get_raw_desc(),
-          variantPack.get_raw_desc()) != CUDNN_STATUS_SUCCESS) {
+          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
+      CUDNN_STATUS_SUCCESS) {
     // Discard the captured graph when failed.
     capture.discard = true;
     return false;

View File

@@ -198,7 +198,6 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 
 CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
   CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
-  CHECK_CUDNN_ERROR(cudnnSetStream(d.cudnn_handle(), stream()));
 }
 
 void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {

View File

@@ -32,9 +32,6 @@ cuda_skip = {
     "TestConvTranspose.test_torch_conv_transpose_3D",
     "TestConvTranspose.test_torch_conv_transpose_3D_grad",
     "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
-    "TestLayers.test_conv1d",
-    "TestLayers.test_conv2d",
-    "TestVmap.test_vmap_conv",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",