fix cublas on h100

This commit is contained in:
Awni Hannun
2025-08-05 20:05:13 -07:00
parent fa89f0b150
commit 193dcb8553

View File

@@ -213,7 +213,7 @@ void Matmul::run_impl(
matmul_desc_,
a_desc_,
b_desc_,
out_desc_, // TODO should that be c_desc is it's set?
c ? c_desc_ : out_desc_,
out_desc_,
pref_,
1,
@@ -226,8 +226,10 @@ void Matmul::run_impl(
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(heuristic_.workspaceSize),
allocator::malloc(nbytes),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);