Do error check for cublas handle

This commit is contained in:
Cheng
2025-07-22 00:28:43 +00:00
parent 4c0dc7745f
commit 48e796bb91
4 changed files with 13 additions and 13 deletions

View File

@@ -51,14 +51,14 @@ Device::Device(int device) : device_(device) {
} }
// The cublasLt handle is used by matmul. // The cublasLt handle is used by matmul.
make_current(); make_current();
cublasLtCreate(&lt_); CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
// The cudnn handle is used by Convolution. // The cudnn handle is used by Convolution.
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_)); CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
} }
Device::~Device() { Device::~Device() {
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_)); CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
cublasLtDestroy(lt_); CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
} }
void Device::make_current() { void Device::make_current() {

View File

@@ -8,7 +8,6 @@
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
#include <cublasLt.h>
#include <fmt/format.h> #include <fmt/format.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
@@ -18,16 +17,6 @@ namespace mlx::core {
namespace cu { namespace cu {
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
// Throws std::runtime_error when a cuBLAS call did not succeed.
// `name` labels the failed call in the message (CHECK_CUBLAS_ERROR passes the
// stringified command) and `err` is the status that call returned.
void check_cublas_error(const char* name, cublasStatus_t err) {
  // Nothing to report on success.
  if (err == CUBLAS_STATUS_SUCCESS) {
    return;
  }
  // TODO: Use cublasGetStatusString when it is widely available.
  // Until then the raw numeric status code is reported.
  const int code = static_cast<int>(err);
  throw std::runtime_error(
      fmt::format("{} failed with code: {}.", name, code));
}
struct CublasPreference { struct CublasPreference {
CublasPreference(Device& device) { CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB

View File

@@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
} }
// Throws std::runtime_error when a cuBLAS call did not succeed.
// `name` labels the failed call in the message (CHECK_CUBLAS_ERROR passes the
// stringified command) and `err` is the status that call returned.
void check_cublas_error(const char* name, cublasStatus_t err) {
  // Nothing to report on success.
  if (err == CUBLAS_STATUS_SUCCESS) {
    return;
  }
  // TODO: Use cublasGetStatusString when it is widely available.
  // Until then the raw numeric status code is reported.
  const int code = static_cast<int>(err);
  throw std::runtime_error(
      fmt::format("{} failed with code: {}.", name, code));
}
void check_cuda_error(const char* name, cudaError_t err) { void check_cuda_error(const char* name, cudaError_t err) {
if (err != cudaSuccess) { if (err != cudaSuccess) {
throw std::runtime_error( throw std::runtime_error(

View File

@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <cublasLt.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
@@ -33,10 +34,12 @@ class CudaStream {
}; };
// Throw exception if the cuda API does not succeed. // Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err); void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err); void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed. // The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Convert Dtype to CUDA C++ types. // Convert Dtype to CUDA C++ types.