Set cuDNN stream before execution

This commit is contained in:
Cheng
2025-07-20 05:24:48 -07:00
parent 67a5f7b2a8
commit 3d16cb5071
3 changed files with 7 additions and 11 deletions

View File

@@ -184,25 +184,25 @@ bool execute_plan(
.setUids(3, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
if (cudnnBackendPopulateCudaGraph(
encoder.device().cudnn_handle(),
plan.get_raw_desc(),
variantPack.get_raw_desc(),
graph) != CUDNN_STATUS_SUCCESS) {
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
encoder.add_graph_node(graph);
#else
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
encoder.device().cudnn_handle(),
plan.get_raw_desc(),
variantPack.get_raw_desc()) != CUDNN_STATUS_SUCCESS) {
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;

View File

@@ -198,7 +198,6 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
CHECK_CUDNN_ERROR(cudnnSetStream(d.cudnn_handle(), stream()));
}
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {

View File

@@ -32,9 +32,6 @@ cuda_skip = {
"TestConvTranspose.test_torch_conv_transpose_3D",
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
"TestLayers.test_conv1d",
"TestLayers.test_conv2d",
"TestVmap.test_vmap_conv",
# FFTs NYI
"TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two",