mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 19:56:40 +08:00
fix cublas on h100 (#2466)
This commit is contained in:
parent
fa89f0b150
commit
7bb96e4249
@ -213,7 +213,7 @@ void Matmul::run_impl(
|
||||
matmul_desc_,
|
||||
a_desc_,
|
||||
b_desc_,
|
||||
out_desc_, // TODO should that be c_desc if it's set?
|
||||
c ? c_desc_ : out_desc_,
|
||||
out_desc_,
|
||||
pref_,
|
||||
1,
|
||||
@ -226,8 +226,10 @@ void Matmul::run_impl(
|
||||
|
||||
void* workspace_ptr = nullptr;
|
||||
if (heuristic_.workspaceSize > 0) {
|
||||
// Ensure workspace is 256-byte aligned
|
||||
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
|
||||
array workspace(
|
||||
allocator::malloc(heuristic_.workspaceSize),
|
||||
allocator::malloc(nbytes),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
|
Loading…
Reference in New Issue
Block a user