mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Do error check for cublas handle
This commit is contained in:
@@ -51,14 +51,14 @@ Device::Device(int device) : device_(device) {
|
|||||||
}
|
}
|
||||||
// The cublasLt handle is used by matmul.
|
// The cublasLt handle is used by matmul.
|
||||||
make_current();
|
make_current();
|
||||||
cublasLtCreate(<_);
|
CHECK_CUBLAS_ERROR(cublasLtCreate(<_));
|
||||||
// The cudnn handle is used by Convolution.
|
// The cudnn handle is used by Convolution.
|
||||||
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
|
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
|
||||||
}
|
}
|
||||||
|
|
||||||
Device::~Device() {
|
Device::~Device() {
|
||||||
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
|
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
|
||||||
cublasLtDestroy(lt_);
|
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::make_current() {
|
void Device::make_current() {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@
|
|||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cublasLt.h>
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
@@ -18,16 +17,6 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
|
||||||
|
|
||||||
void check_cublas_error(const char* name, cublasStatus_t err) {
|
|
||||||
if (err != CUBLAS_STATUS_SUCCESS) {
|
|
||||||
// TODO: Use cublasGetStatusString when it is widely available.
|
|
||||||
throw std::runtime_error(
|
|
||||||
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CublasPreference {
|
struct CublasPreference {
|
||||||
CublasPreference(Device& device) {
|
CublasPreference(Device& device) {
|
||||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||||
|
|||||||
@@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
|
|||||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||||
|
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||||
|
// TODO: Use cublasGetStatusString when it is widely available.
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void check_cuda_error(const char* name, cudaError_t err) {
|
void check_cuda_error(const char* name, cudaError_t err) {
|
||||||
if (err != cudaSuccess) {
|
if (err != cudaSuccess) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <cublasLt.h>
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
@@ -33,10 +34,12 @@ class CudaStream {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Throw exception if the cuda API does not succeed.
|
// Throw exception if the cuda API does not succeed.
|
||||||
|
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||||
void check_cuda_error(const char* name, cudaError_t err);
|
void check_cuda_error(const char* name, cudaError_t err);
|
||||||
void check_cuda_error(const char* name, CUresult err);
|
void check_cuda_error(const char* name, CUresult err);
|
||||||
|
|
||||||
// The macro version that prints the command that failed.
|
// The macro version that prints the command that failed.
|
||||||
|
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||||
|
|
||||||
// Convert Dtype to CUDA C++ types.
|
// Convert Dtype to CUDA C++ types.
|
||||||
|
|||||||
Reference in New Issue
Block a user