mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 19:48:15 +08:00
fix
This commit is contained in:
@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
|
||||
void** data_ptrs,
|
||||
F&& execute) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
void* workspace_ptr = nullptr;
|
||||
if (workspace_size > 0) {
|
||||
array workspace(
|
||||
workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream())
|
||||
: allocator::Buffer(nullptr),
|
||||
cu::malloc_async(workspace_size, encoder.stream()),
|
||||
{workspace_size},
|
||||
uint8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = gpu_ptr<void>(workspace);
|
||||
}
|
||||
|
||||
auto args = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(gpu_ptr<void>(workspace))
|
||||
.setWorkspacePointer(workspace_ptr)
|
||||
.setDataPointers(num_args, data_ptrs)
|
||||
.setUids(num_args, uids)
|
||||
.build();
|
||||
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
|
||||
return false;
|
||||
}
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user