Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Latest commit:
* Use async cuda malloc managed with cuda 13
* add pool threshold
* refactor for regular cuda malloc
* load eval gpu for cuda
* remove use of cuda pool, use cuda free async
* fixes + comment
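The bullets above refer to CUDA's stream-ordered allocator. As a hedged illustration of those runtime APIs only (this is not MLX's allocator code; the device index, threshold, and allocation size below are placeholder values), cudaMallocAsync/cudaFreeAsync allocate and free in stream order, and the default memory pool's release threshold controls how much freed memory the pool keeps cached instead of returning it to the driver at synchronization points:

// Hedged sketch: CUDA stream-ordered allocation with a pool release threshold.
// Illustrative only; not taken from MLX. Values below are placeholders.
#include <cuda_runtime.h>

#include <cstdint>

int main() {
  int device = 0;
  cudaSetDevice(device);

  // Keep freed blocks cached in the default pool instead of releasing them
  // back to the driver after every synchronization (UINT64_MAX = cache all).
  cudaMemPool_t pool;
  cudaDeviceGetDefaultMemPool(&pool, device);
  uint64_t threshold = UINT64_MAX;
  cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Allocation and free are both ordered on `stream`; no device-wide sync needed.
  void* ptr = nullptr;
  cudaMallocAsync(&ptr, 1 << 20, stream);
  // ... enqueue kernels that use `ptr` on `stream` ...
  cudaFreeAsync(ptr, stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}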
83 lines · 1.9 KiB · C++
// Copyright © 2025 Apple Inc.

#pragma once

#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cassert>
#include <utility>

namespace mlx::core {

// Throw an exception if the CUDA API call does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);

// Macro versions that report the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))

// Base class for RAII-managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
 public:
  CudaHandle(Handle handle = nullptr) : handle_(handle) {}

  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
    assert(this != &other);
    other.handle_ = nullptr;
  }

  ~CudaHandle() {
    reset();
  }

  CudaHandle(const CudaHandle&) = delete;
  CudaHandle& operator=(const CudaHandle&) = delete;

  CudaHandle& operator=(CudaHandle&& other) {
    assert(this != &other);
    reset();
    std::swap(handle_, other.handle_);
    return *this;
  }

  // Destroy the held resource (if any) and reset the handle to null.
  void reset() {
    if (handle_ != nullptr) {
      CHECK_CUDA_ERROR(Destroy(handle_));
      handle_ = nullptr;
    }
  }

  // Implicit conversion so the wrapper can be passed directly to CUDA APIs.
  operator Handle() const {
    return handle_;
  }

 protected:
  Handle handle_;
};

namespace cu {
class Device;
} // namespace cu

// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
 public:
  using CudaHandle::CudaHandle;
  explicit CudaGraph(cu::Device& device);
  void end_capture(cudaStream_t stream);
};

class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
 public:
  void instantiate(cudaGraph_t graph);
};

class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
 public:
  explicit CudaStream(cu::Device& device);
};

} // namespace mlx::core
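As a hedged usage sketch (assuming the header above is included; the "cuda/utils.h" path and the CudaEvent wrapper are placeholders introduced for illustration, not declarations from this file), the CudaHandle template can wrap any CUDA object whose destroy function has the signature cudaError_t(Handle), and the conversion operator lets the wrapper be passed straight to CUDA runtime calls:

// Hedged usage sketch; "cuda/utils.h" is a placeholder path for the header above,
// and CudaEvent is a hypothetical wrapper, not one declared in this file.
#include "cuda/utils.h"

namespace mlx::core {

// RAII event wrapper built on the same base as CudaGraph/CudaStream.
class CudaEvent : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 public:
  CudaEvent() {
    // CHECK_CUDA_ERROR throws with the stringified command on failure.
    CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, cudaEventDisableTiming));
  }
};

void record_and_wait(cudaStream_t producer, cudaStream_t consumer) {
  CudaEvent event;
  // operator Handle() converts the wrapper to a raw cudaEvent_t.
  CHECK_CUDA_ERROR(cudaEventRecord(event, producer));
  CHECK_CUDA_ERROR(cudaStreamWaitEvent(consumer, event, 0));
} // cudaEventDestroy runs automatically when `event` goes out of scope.

} // namespace mlx::core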