mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Set cudnn stream before execution
This commit is contained in:
@@ -184,25 +184,25 @@ bool execute_plan(
|
||||
.setUids(3, uids)
|
||||
.build();
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||
cudaGraph_t graph;
|
||||
cudaGraphCreate(&graph, 0);
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||
if (cudnnBackendPopulateCudaGraph(
|
||||
encoder.device().cudnn_handle(),
|
||||
plan.get_raw_desc(),
|
||||
variantPack.get_raw_desc(),
|
||||
graph) != CUDNN_STATUS_SUCCESS) {
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
#else
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(
|
||||
encoder.device().cudnn_handle(),
|
||||
plan.get_raw_desc(),
|
||||
variantPack.get_raw_desc()) != CUDNN_STATUS_SUCCESS) {
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
|
@@ -198,7 +198,6 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
|
||||
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
CHECK_CUDNN_ERROR(cudnnSetStream(d.cudnn_handle(), stream()));
|
||||
}
|
||||
|
||||
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
|
||||
|
@@ -32,9 +32,6 @@ cuda_skip = {
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||
"TestLayers.test_conv1d",
|
||||
"TestLayers.test_conv2d",
|
||||
"TestVmap.test_vmap_conv",
|
||||
# FFTs NYI
|
||||
"TestFFT.test_fft",
|
||||
"TestFFT.test_fft_big_powers_of_two",
|
||||
|
Reference in New Issue
Block a user