This commit is contained in:
Awni Hannun
2025-11-03 16:43:19 -08:00
parent cc6df9fc8a
commit 529842fed9

View File

@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
void** data_ptrs, void** data_ptrs,
F&& execute) { F&& execute) {
int workspace_size = plan.getWorkspaceSize(); int workspace_size = plan.getWorkspaceSize();
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace( array workspace(
workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream()) cu::malloc_async(workspace_size, encoder.stream()),
: allocator::Buffer(nullptr),
{workspace_size}, {workspace_size},
uint8); uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder() auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(gpu_ptr<void>(workspace)) .setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs) .setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids) .setUids(num_args, uids)
.build(); .build();
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
return false; return false;
} }
encoder.add_temporary(workspace);
return true; return true;
} }