mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 19:48:15 +08:00
fix
This commit is contained in:
@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
|
|||||||
void** data_ptrs,
|
void** data_ptrs,
|
||||||
F&& execute) {
|
F&& execute) {
|
||||||
int workspace_size = plan.getWorkspaceSize();
|
int workspace_size = plan.getWorkspaceSize();
|
||||||
array workspace(
|
void* workspace_ptr = nullptr;
|
||||||
workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream())
|
if (workspace_size > 0) {
|
||||||
: allocator::Buffer(nullptr),
|
array workspace(
|
||||||
{workspace_size},
|
cu::malloc_async(workspace_size, encoder.stream()),
|
||||||
uint8);
|
{workspace_size},
|
||||||
|
uint8);
|
||||||
|
encoder.add_temporary(workspace);
|
||||||
|
workspace_ptr = gpu_ptr<void>(workspace);
|
||||||
|
}
|
||||||
|
|
||||||
auto args = cudnn_frontend::VariantPackBuilder()
|
auto args = cudnn_frontend::VariantPackBuilder()
|
||||||
.setWorkspacePointer(gpu_ptr<void>(workspace))
|
.setWorkspacePointer(workspace_ptr)
|
||||||
.setDataPointers(num_args, data_ptrs)
|
.setDataPointers(num_args, data_ptrs)
|
||||||
.setUids(num_args, uids)
|
.setUids(num_args, uids)
|
||||||
.build();
|
.build();
|
||||||
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
encoder.add_temporary(workspace);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user