Mirror of https://github.com/ml-explore/mlx.git, synced 2025-11-06 03:58:12 +08:00
remove use of cuda pool, use cuda free async
@@ -349,7 +349,10 @@ class array {
     return array_desc_->data;
   }
 
-  // Return a raw pointer to the arrays data
+  // Return a raw pointer to the arrays data. This function may do a copy if
+  // the underlying buffer is not accessible on the CPU. When accessing the
+  // data for GPU kernels, be sure to use the correct method / function for the
+  // given backend to access the GPU pointer.
   template <typename T>
   T* data() {
     return reinterpret_cast<T*>(
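Aside: the new comment distinguishes CPU-side access from GPU-side access. A minimal sketch of that distinction, assuming the gpu_ptr<T>(array&) helper that appears later in this diff; the function example() and the header path are illustrative, not part of the commit:

// Illustrative sketch only; header path for gpu_ptr is assumed.
#include "mlx/array.h"
#include "mlx/backend/cuda/utils.h" // assumed location of gpu_ptr<T>(array&)

void example(mlx::core::array& arr) {
  // Host-side access: may copy the buffer into CPU-accessible memory first.
  float* host_ptr = arr.data<float>();

  // Device-side access for CUDA kernels: returns the raw GPU pointer, no copy.
  float* dev_ptr = mlx::core::gpu_ptr<float>(arr);

  (void)host_ptr;
  (void)dev_ptr;
}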
@@ -68,7 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
   next_free_ = next_free_->next;
   b->buf.data = static_cast<char*>(data_) + i * small_block_size;
   b->buf.size = small_block_size;
-  b->buf.managed = true;
+  b->buf.device = -1;
   return &b->buf;
 }
 
@@ -94,10 +94,15 @@ CudaAllocator::CudaAllocator()
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
-  int loc = 0;
-  CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&cuda_pool_, loc));
-  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
-      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
+
+  int device_count = 0;
+  CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
+  int curr_device = 0;
+  CHECK_CUDA_ERROR(cudaGetDevice(&curr_device));
+  for (int i = 0; i < device_count; ++i) {
+    free_streams_.emplace_back(
+        cu::device(mlx::core::Device{mlx::core::Device::gpu, i}));
+  }
 }
 
 Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
@@ -127,13 +132,16 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
   }
   lock.unlock();
   if (!buf) {
-    bool managed = stream == nullptr;
-    buf = new CudaBuffer{nullptr, size, managed};
+    int device = -1;
+    if (stream != nullptr) {
+      cudaStreamGetDevice(stream, &device);
+    }
+    buf = new CudaBuffer{nullptr, size, device};
     cudaError_t err;
-    if (managed) {
+    if (device == -1) {
       err = cudaMallocManaged(&buf->data, size);
     } else {
-      err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
+      err = cudaMallocAsync(&buf->data, size, stream);
     }
     if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
       throw std::runtime_error(fmt::format(
@@ -188,7 +196,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf->data);
+    if (buf->device >= 0) {
+      cudaFreeAsync(buf->data, free_streams_[buf->device]);
+    } else {
+      cudaFree(buf->data);
+    }
     delete buf;
   }
 }
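Aside: a standalone sketch (not MLX code) of the stream-ordered CUDA allocation APIs the allocator now calls directly, cudaMallocAsync and cudaFreeAsync; the buffer size and stream handling below are illustrative:

// Illustrative sketch of stream-ordered allocation; not part of this commit.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // The allocation is ordered on `stream` and draws from the device's
  // default memory pool; no device-wide synchronization is required.
  void* ptr = nullptr;
  cudaError_t err = cudaMallocAsync(&ptr, 1 << 20, stream);
  if (err != cudaSuccess) {
    std::printf("alloc failed: %s\n", cudaGetErrorString(err));
    return 1;
  }

  // ... launch kernels that use `ptr` on `stream` here ...

  // The free is also stream-ordered: it does not block the host or
  // synchronize the device, and the memory becomes reusable once the free
  // point is reached in stream order.
  cudaFreeAsync(ptr, stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}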
@@ -213,9 +225,6 @@ size_t CudaAllocator::get_memory_limit() {
 size_t CudaAllocator::set_memory_limit(size_t limit) {
   std::lock_guard lock(mutex_);
   std::swap(limit, memory_limit_);
-  CHECK_CUDA_ERROR(cudaMemPoolTrimTo(cuda_pool_, memory_limit_));
-  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
-      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
   return limit;
 }
 
@@ -265,12 +274,12 @@ void* Buffer::raw_ptr() {
     return nullptr;
   }
   auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
-  if (!cbuf.managed) {
+  if (cbuf.device != -1) {
     // TODO maybe make this async on a i/o stream to avoid synchronizing the
     // device on malloc/and free
     void* new_data;
     CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
-    cbuf.managed = true;
+    cbuf.device = -1;
     CHECK_CUDA_ERROR(
         cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
     CHECK_CUDA_ERROR(cudaFree(cbuf.data));
@@ -4,6 +4,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/buffer_cache.h"
+#include "mlx/backend/cuda/cuda_utils.h"
 
 #include <cuda_runtime.h>
 #include <mutex>
@@ -18,11 +19,11 @@ using allocator::Buffer;
 struct CudaBuffer {
   void* data;
   size_t size;
-  bool managed;
+  int device; // -1 for managed
 };
 
 template <typename T>
-T* gpu_ptr(Buffer buf) {
+inline T* gpu_ptr(Buffer buf) {
   return static_cast<T*>(static_cast<cu::CudaBuffer*>(buf.ptr())->data);
 }
 
@@ -78,8 +79,8 @@ class CudaAllocator : public allocator::Allocator {
   BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
+  std::vector<CudaStream> free_streams_;
   SmallSizePool scalar_pool_;
-  cudaMemPool_t cuda_pool_{nullptr};
 };
 
 CudaAllocator& allocator();
mlx/backend/cuda/cuda_utils.h (new file, 82 lines)
@@ -0,0 +1,82 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cublasLt.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+namespace mlx::core {
+
+// Throw exception if the cuda API does not succeed.
+void check_cublas_error(const char* name, cublasStatus_t err);
+void check_cuda_error(const char* name, cudaError_t err);
+void check_cuda_error(const char* name, CUresult err);
+
+// The macro version that prints the command that failed.
+#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
+#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
+
+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+namespace cu {
+class Device;
+}; // namespace cu
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  explicit CudaStream(cu::Device& device);
+};
+
+} // namespace mlx::core
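Aside: a brief sketch of how the CudaHandle RAII base above can be used; CudaEvent is a hypothetical wrapper written for illustration and is not part of this commit:

// Hypothetical example, written as if inside MLX so CHECK_CUDA_ERROR resolves.
#include <cuda_runtime.h>
#include "mlx/backend/cuda/cuda_utils.h"

namespace mlx::core {

// Owns a cudaEvent_t; cudaEventDestroy runs automatically via ~CudaHandle().
class CudaEvent : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 public:
  CudaEvent() {
    CHECK_CUDA_ERROR(cudaEventCreate(&handle_));
  }
};

void record_marker(cudaStream_t stream) {
  CudaEvent event; // created here
  // operator Handle() converts the wrapper to the raw cudaEvent_t.
  CHECK_CUDA_ERROR(cudaEventRecord(event, stream));
} // event destroyed here; no manual cudaEventDestroy needed

} // namespace mlx::core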
@@ -4,101 +4,12 @@
 
 #pragma once
 
-#include <cublasLt.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
 #include "mlx/array.h"
 #include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/cuda_utils.h"
 
 namespace mlx::core {
 
-namespace cu {
-class Device;
-
-}
-
-template <typename T>
-T* gpu_ptr(array& arr) {
-  return cu::gpu_ptr<T>(arr.buffer());
-}
-
-template <typename T>
-const T* gpu_ptr(const array& arr) {
-  return cu::gpu_ptr<T>(arr.buffer());
-}
-
-struct Dtype;
-
-// Throw exception if the cuda API does not succeed.
-void check_cublas_error(const char* name, cublasStatus_t err);
-void check_cuda_error(const char* name, cudaError_t err);
-void check_cuda_error(const char* name, CUresult err);
-
-// The macro version that prints the command that failed.
-#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
-#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
-
-// Convert Dtype to CUDA C++ types.
-const char* dtype_to_cuda_type(const Dtype& dtype);
-
-// Base class for RAII managed CUDA resources.
-template <typename Handle, cudaError_t (*Destroy)(Handle)>
-class CudaHandle {
- public:
-  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
-
-  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
-    assert(this != &other);
-    other.handle_ = nullptr;
-  }
-
-  ~CudaHandle() {
-    reset();
-  }
-
-  CudaHandle(const CudaHandle&) = delete;
-  CudaHandle& operator=(const CudaHandle&) = delete;
-
-  CudaHandle& operator=(CudaHandle&& other) {
-    assert(this != &other);
-    reset();
-    std::swap(handle_, other.handle_);
-    return *this;
-  }
-
-  void reset() {
-    if (handle_ != nullptr) {
-      CHECK_CUDA_ERROR(Destroy(handle_));
-      handle_ = nullptr;
-    }
-  }
-
-  operator Handle() const {
-    return handle_;
-  }
-
- protected:
-  Handle handle_;
-};
-
-// Wrappers of CUDA resources.
-class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
- public:
-  using CudaHandle::CudaHandle;
-  explicit CudaGraph(cu::Device& device);
-  void end_capture(cudaStream_t stream);
-};
-
-class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
- public:
-  void instantiate(cudaGraph_t graph);
-};
-
-class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
- public:
-  explicit CudaStream(cu::Device& device);
-};
-
 template <typename T>
 inline uint max_occupancy_block_dim(T kernel) {
   int _, block_dim;
@@ -112,4 +23,19 @@ inline uint max_occupancy_block_dim(T kernel) {
   return block_dim;
 }
 
+template <typename T>
+inline T* gpu_ptr(array& arr) {
+  return cu::gpu_ptr<T>(arr.buffer());
+}
+
+template <typename T>
+inline const T* gpu_ptr(const array& arr) {
+  return cu::gpu_ptr<T>(arr.buffer());
+}
+
+struct Dtype;
+
+// Convert Dtype to CUDA C++ types.
+const char* dtype_to_cuda_type(const Dtype& dtype);
+
 } // namespace mlx::core
@@ -326,8 +326,8 @@ class NCCLGroup : public GroupImpl {
     auto& encoder = cu::get_command_encoder(stream);
 
     CHECK_NCCL(ncclAllReduce(
-        input.data<T>(),
-        output.data<T>(),
+        gpu_ptr<T>(input),
+        gpu_ptr<T>(output),
         input.size(),
         dt,
         op,