mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 00:08:09 +08:00
fix cublas on h100 (#2466)
This commit is contained in:
@@ -213,7 +213,7 @@ void Matmul::run_impl(
 matmul_desc_,
 a_desc_,
 b_desc_,
-out_desc_, // TODO should that be c_desc is it's set?
+c ? c_desc_ : out_desc_,
 out_desc_,
 pref_,
 1,
@@ -226,8 +226,10 @@ void Matmul::run_impl(
 
 void* workspace_ptr = nullptr;
 if (heuristic_.workspaceSize > 0) {
+// Ensure workspace is 256-byte aligned
+int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
 array workspace(
-allocator::malloc(heuristic_.workspaceSize),
+allocator::malloc(nbytes),
 {static_cast<int>(heuristic_.workspaceSize)},
 int8);
 encoder.add_temporary(workspace);
Reference in New Issue
Block a user