This commit is contained in:
Awni Hannun
2025-11-03 16:43:19 -08:00
parent cc6df9fc8a
commit 529842fed9

View File

@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream())
: allocator::Buffer(nullptr),
cu::malloc_async(workspace_size, encoder.stream()),
{workspace_size},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(gpu_ptr<void>(workspace))
.setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
return false;
}
encoder.add_temporary(workspace);
return true;
}