mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Set cudnn stream before execution
This commit is contained in:
@@ -184,25 +184,25 @@ bool execute_plan(
|
|||||||
.setUids(3, uids)
|
.setUids(3, uids)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
cudnnSetStream(handle, encoder.stream());
|
||||||
|
|
||||||
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||||
cudaGraph_t graph;
|
cudaGraph_t graph;
|
||||||
cudaGraphCreate(&graph, 0);
|
cudaGraphCreate(&graph, 0);
|
||||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||||
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||||
if (cudnnBackendPopulateCudaGraph(
|
if (cudnnBackendPopulateCudaGraph(
|
||||||
encoder.device().cudnn_handle(),
|
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
|
||||||
plan.get_raw_desc(),
|
CUDNN_STATUS_SUCCESS) {
|
||||||
variantPack.get_raw_desc(),
|
|
||||||
graph) != CUDNN_STATUS_SUCCESS) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
encoder.add_graph_node(graph);
|
encoder.add_graph_node(graph);
|
||||||
#else
|
#else
|
||||||
auto capture = encoder.capture_context();
|
auto capture = encoder.capture_context();
|
||||||
if (cudnnBackendExecute(
|
if (cudnnBackendExecute(
|
||||||
encoder.device().cudnn_handle(),
|
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
|
||||||
plan.get_raw_desc(),
|
CUDNN_STATUS_SUCCESS) {
|
||||||
variantPack.get_raw_desc()) != CUDNN_STATUS_SUCCESS) {
|
|
||||||
// Discard the captured graph when failed.
|
// Discard the captured graph when failed.
|
||||||
capture.discard = true;
|
capture.discard = true;
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -198,7 +198,6 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
|||||||
|
|
||||||
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
|
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||||
CHECK_CUDNN_ERROR(cudnnSetStream(d.cudnn_handle(), stream()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
|
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
|
||||||
|
|||||||
@@ -32,9 +32,6 @@ cuda_skip = {
|
|||||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||||
"TestLayers.test_conv1d",
|
|
||||||
"TestLayers.test_conv2d",
|
|
||||||
"TestVmap.test_vmap_conv",
|
|
||||||
# FFTs NYI
|
# FFTs NYI
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
|
|||||||
Reference in New Issue
Block a user