mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 19:56:40 +08:00
fix cublas on h100 (#2466)
This commit is contained in:
parent
fa89f0b150
commit
7bb96e4249
@ -213,7 +213,7 @@ void Matmul::run_impl(
|
|||||||
matmul_desc_,
|
matmul_desc_,
|
||||||
a_desc_,
|
a_desc_,
|
||||||
b_desc_,
|
b_desc_,
|
||||||
out_desc_, // TODO should that be c_desc is it's set?
|
c ? c_desc_ : out_desc_,
|
||||||
out_desc_,
|
out_desc_,
|
||||||
pref_,
|
pref_,
|
||||||
1,
|
1,
|
||||||
@ -226,8 +226,10 @@ void Matmul::run_impl(
|
|||||||
|
|
||||||
void* workspace_ptr = nullptr;
|
void* workspace_ptr = nullptr;
|
||||||
if (heuristic_.workspaceSize > 0) {
|
if (heuristic_.workspaceSize > 0) {
|
||||||
|
// Ensure workspace is 256-byte aligned
|
||||||
|
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
|
||||||
array workspace(
|
array workspace(
|
||||||
allocator::malloc(heuristic_.workspaceSize),
|
allocator::malloc(nbytes),
|
||||||
{static_cast<int>(heuristic_.workspaceSize)},
|
{static_cast<int>(heuristic_.workspaceSize)},
|
||||||
int8);
|
int8);
|
||||||
encoder.add_temporary(workspace);
|
encoder.add_temporary(workspace);
|
||||||
|
Loading…
Reference in New Issue
Block a user