Share the cuBLAS status check across the CUDA backend.

Moves check_cublas_error() and the CHECK_CUBLAS_ERROR macro out of matmul.cpp
into cuda/utils.{h,cpp}, and wraps the previously unchecked cublasLtCreate /
cublasLtDestroy calls in Device with the macro.

NOTE(review): this patch was rebuilt from a copy whose line breaks and
angle-bracket text had been lost. The `&lt_` argument, the `<int>` cast target
and every `#include <...>` header name below are reconstructions -- confirm
them against the tree before applying.
NOTE(review): Device::~Device() now runs a check that throws
std::runtime_error on failure. Unless the destructor is declared
noexcept(false), a failing cublasLtDestroy would terminate the process (the
existing cudnnDestroy check behaves the same way) -- confirm this is intended.

diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 2d567af8d..41a583996 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -51,14 +51,14 @@ Device::Device(int device) : device_(device) {
   }
   // The cublasLt handle is used by matmul.
   make_current();
-  cublasLtCreate(&lt_);
+  CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
   // The cudnn handle is used by Convolution.
   CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
 }
 
 Device::~Device() {
   CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
-  cublasLtDestroy(lt_);
+  CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
 }
 
 void Device::make_current() {
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 1bca7c730..efddf2506 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -8,7 +8,6 @@
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
-#include <cublasLt.h>
 #include <fmt/format.h>
 
 #include <numeric>
@@ -18,16 +17,6 @@ namespace mlx::core {
 
 namespace cu {
 
-#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
-
-void check_cublas_error(const char* name, cublasStatus_t err) {
-  if (err != CUBLAS_STATUS_SUCCESS) {
-    // TODO: Use cublasGetStatusString when it is widely available.
-    throw std::runtime_error(
-        fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
-  }
-}
-
 struct CublasPreference {
   CublasPreference(Device& device) {
     // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp
index 1c12fa4df..baab3b2a5 100644
--- a/mlx/backend/cuda/utils.cpp
+++ b/mlx/backend/cuda/utils.cpp
@@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
   CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
 }
 
+void check_cublas_error(const char* name, cublasStatus_t err) {
+  if (err != CUBLAS_STATUS_SUCCESS) {
+    // TODO: Use cublasGetStatusString when it is widely available.
+    throw std::runtime_error(
+        fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
+  }
+}
+
 void check_cuda_error(const char* name, cudaError_t err) {
   if (err != cudaSuccess) {
     throw std::runtime_error(
diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h
index bfb02c5b6..f2b6b16cb 100644
--- a/mlx/backend/cuda/utils.h
+++ b/mlx/backend/cuda/utils.h
@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include <cublasLt.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 
@@ -33,10 +34,12 @@ class CudaStream {
 };
 
 // Throw exception if the cuda API does not succeed.
+void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
 void check_cuda_error(const char* name, CUresult err);
 
 // The macro version that prints the command that failed.
+#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
 #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
 
 // Convert Dtype to CUDA C++ types.