Fix cuBLAS on H100 (#2466)

This commit is contained in:
Awni Hannun 2025-08-06 06:18:58 -07:00 committed by GitHub
parent fa89f0b150
commit 7bb96e4249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -213,7 +213,7 @@ void Matmul::run_impl(
matmul_desc_,
a_desc_,
b_desc_,
out_desc_, // TODO: should that be c_desc_ if it's set?
c ? c_desc_ : out_desc_,
out_desc_,
pref_,
1,
@ -226,8 +226,10 @@ void Matmul::run_impl(
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(heuristic_.workspaceSize),
allocator::malloc(nbytes),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);