Do error check for cublas handle

This commit is contained in:
Cheng
2025-07-22 00:28:43 +00:00
parent 4c0dc7745f
commit 48e796bb91
4 changed files with 13 additions and 13 deletions

View File

@@ -51,14 +51,14 @@ Device::Device(int device) : device_(device) {
}
// The cublasLt handle is used by matmul.
make_current();
cublasLtCreate(&lt_);
CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
// The cudnn handle is used by Convolution.
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
}
Device::~Device() {
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
cublasLtDestroy(lt_);
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
}
void Device::make_current() {

View File

@@ -8,7 +8,6 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <cublasLt.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
@@ -18,16 +17,6 @@ namespace mlx::core {
namespace cu {
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB

View File

@@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
}
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
void check_cuda_error(const char* name, cudaError_t err) {
if (err != cudaSuccess) {
throw std::runtime_error(

View File

@@ -4,6 +4,7 @@
#pragma once
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -33,10 +34,12 @@ class CudaStream {
};
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Convert Dtype to CUDA C++ types.