From 7bb96e42496dd0b65b93233b6f7ad9f1654a515a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 6 Aug 2025 06:18:58 -0700
Subject: [PATCH] fix cublas on h100 (#2466)

---
 mlx/backend/cuda/gemms/cublas_gemm.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index a0e936fd46..61f12ba1db 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -213,7 +213,7 @@ void Matmul::run_impl(
       matmul_desc_,
       a_desc_,
       b_desc_,
-      out_desc_, // TODO should that be c_desc is it's set?
+      c ? c_desc_ : out_desc_,
       out_desc_,
       pref_,
       1,
@@ -226,8 +226,10 @@ void Matmul::run_impl(
 
   void* workspace_ptr = nullptr;
   if (heuristic_.workspaceSize > 0) {
+    // Ensure workspace is 256-byte aligned
+    int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
     array workspace(
-        allocator::malloc(heuristic_.workspaceSize),
+        allocator::malloc(nbytes),
        {static_cast<int>(heuristic_.workspaceSize)},
         int8);
     encoder.add_temporary(workspace);
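
Note on the workspace change: the added line rounds the cuBLASLt-reported
workspace size up to a multiple of 256 bytes before allocating. Below is a
minimal standalone sketch of that rounding, assuming ceil_div behaves like the
cuda::ceil_div helper referenced in the diff (integer division rounded up);
the size used is illustrative only and not taken from the patch:

  #include <cstddef>
  #include <cstdio>

  // Stand-in for cuda::ceil_div: integer division rounded up.
  constexpr std::size_t ceil_div(std::size_t a, std::size_t b) {
    return (a + b - 1) / b;
  }

  int main() {
    // Example workspace size as a GEMM heuristic might report it.
    std::size_t workspace_size = 1000;
    // Round up to a multiple of 256 before allocating, as the patch does.
    std::size_t nbytes = ceil_div(workspace_size, 256) * 256; // -> 1024
    std::printf("workspace %zu bytes -> allocate %zu bytes\n",
                workspace_size, nbytes);
    return 0;
  }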