fix cublas on h100 (#2466)

This commit is contained in:
Awni Hannun 2025-08-06 06:18:58 -07:00 committed by GitHub
parent fa89f0b150
commit 7bb96e4249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -213,7 +213,7 @@ void Matmul::run_impl(
matmul_desc_, matmul_desc_,
a_desc_, a_desc_,
b_desc_, b_desc_,
out_desc_, // TODO should that be c_desc is it's set? c ? c_desc_ : out_desc_,
out_desc_, out_desc_,
pref_, pref_,
1, 1,
@ -226,8 +226,10 @@ void Matmul::run_impl(
void* workspace_ptr = nullptr; void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) { if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace( array workspace(
allocator::malloc(heuristic_.workspaceSize), allocator::malloc(nbytes),
{static_cast<int>(heuristic_.workspaceSize)}, {static_cast<int>(heuristic_.workspaceSize)},
int8); int8);
encoder.add_temporary(workspace); encoder.add_temporary(workspace);