mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 12:09:43 +08:00
remove use of cuda pool, use cuda free async
This commit is contained in:
@@ -349,7 +349,10 @@ class array {
|
|||||||
return array_desc_->data;
|
return array_desc_->data;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a raw pointer to the arrays data
|
// Return a raw pointer to the arrays data. This function may do a copy if
|
||||||
|
// the underlying buffer is not accessible on the CPU. When accessing the
|
||||||
|
// data for GPU kernels, be sure to use the correct method / function for the
|
||||||
|
// given backend to access the GPU pointer.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* data() {
|
T* data() {
|
||||||
return reinterpret_cast<T*>(
|
return reinterpret_cast<T*>(
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
|
|||||||
next_free_ = next_free_->next;
|
next_free_ = next_free_->next;
|
||||||
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||||
b->buf.size = small_block_size;
|
b->buf.size = small_block_size;
|
||||||
b->buf.managed = true;
|
b->buf.device = -1;
|
||||||
return &b->buf;
|
return &b->buf;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,10 +94,15 @@ CudaAllocator::CudaAllocator()
|
|||||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||||
memory_limit_ = total * 0.95;
|
memory_limit_ = total * 0.95;
|
||||||
max_pool_size_ = memory_limit_;
|
max_pool_size_ = memory_limit_;
|
||||||
int loc = 0;
|
|
||||||
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&cuda_pool_, loc));
|
int device_count = 0;
|
||||||
CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
|
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||||
cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
|
int curr_device = 0;
|
||||||
|
CHECK_CUDA_ERROR(cudaGetDevice(&curr_device));
|
||||||
|
for (int i = 0; i < device_count; ++i) {
|
||||||
|
free_streams_.emplace_back(
|
||||||
|
cu::device(mlx::core::Device{mlx::core::Device::gpu, i}));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
|
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
|
||||||
@@ -127,13 +132,16 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
|
|||||||
}
|
}
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
bool managed = stream == nullptr;
|
int device = -1;
|
||||||
buf = new CudaBuffer{nullptr, size, managed};
|
if (stream != nullptr) {
|
||||||
|
cudaStreamGetDevice(stream, &device);
|
||||||
|
}
|
||||||
|
buf = new CudaBuffer{nullptr, size, device};
|
||||||
cudaError_t err;
|
cudaError_t err;
|
||||||
if (managed) {
|
if (device == -1) {
|
||||||
err = cudaMallocManaged(&buf->data, size);
|
err = cudaMallocManaged(&buf->data, size);
|
||||||
} else {
|
} else {
|
||||||
err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
|
err = cudaMallocAsync(&buf->data, size, stream);
|
||||||
}
|
}
|
||||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
@@ -188,7 +196,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
|||||||
if (scalar_pool_.in_pool(buf)) {
|
if (scalar_pool_.in_pool(buf)) {
|
||||||
scalar_pool_.free(buf);
|
scalar_pool_.free(buf);
|
||||||
} else {
|
} else {
|
||||||
cudaFree(buf->data);
|
if (buf->device >= 0) {
|
||||||
|
cudaFreeAsync(buf->data, free_streams_[buf->device]);
|
||||||
|
} else {
|
||||||
|
cudaFree(buf->data);
|
||||||
|
}
|
||||||
delete buf;
|
delete buf;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -213,9 +225,6 @@ size_t CudaAllocator::get_memory_limit() {
|
|||||||
size_t CudaAllocator::set_memory_limit(size_t limit) {
|
size_t CudaAllocator::set_memory_limit(size_t limit) {
|
||||||
std::lock_guard lock(mutex_);
|
std::lock_guard lock(mutex_);
|
||||||
std::swap(limit, memory_limit_);
|
std::swap(limit, memory_limit_);
|
||||||
CHECK_CUDA_ERROR(cudaMemPoolTrimTo(cuda_pool_, memory_limit_));
|
|
||||||
CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
|
|
||||||
cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
|
|
||||||
return limit;
|
return limit;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,12 +274,12 @@ void* Buffer::raw_ptr() {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
|
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
|
||||||
if (!cbuf.managed) {
|
if (cbuf.device != -1) {
|
||||||
// TODO maybe make this async on a i/o stream to avoid synchronizing the
|
// TODO maybe make this async on a i/o stream to avoid synchronizing the
|
||||||
// device on malloc/and free
|
// device on malloc/and free
|
||||||
void* new_data;
|
void* new_data;
|
||||||
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
|
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
|
||||||
cbuf.managed = true;
|
cbuf.device = -1;
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
|
cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
|
||||||
CHECK_CUDA_ERROR(cudaFree(cbuf.data));
|
CHECK_CUDA_ERROR(cudaFree(cbuf.data));
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/common/buffer_cache.h"
|
#include "mlx/backend/common/buffer_cache.h"
|
||||||
|
#include "mlx/backend/cuda/cuda_utils.h"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
@@ -18,11 +19,11 @@ using allocator::Buffer;
|
|||||||
struct CudaBuffer {
|
struct CudaBuffer {
|
||||||
void* data;
|
void* data;
|
||||||
size_t size;
|
size_t size;
|
||||||
bool managed;
|
int device; // -1 for managed
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* gpu_ptr(Buffer buf) {
|
inline T* gpu_ptr(Buffer buf) {
|
||||||
return static_cast<T*>(static_cast<cu::CudaBuffer*>(buf.ptr())->data);
|
return static_cast<T*>(static_cast<cu::CudaBuffer*>(buf.ptr())->data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,8 +79,8 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
BufferCache<CudaBuffer> buffer_cache_;
|
BufferCache<CudaBuffer> buffer_cache_;
|
||||||
size_t active_memory_{0};
|
size_t active_memory_{0};
|
||||||
size_t peak_memory_{0};
|
size_t peak_memory_{0};
|
||||||
|
std::vector<CudaStream> free_streams_;
|
||||||
SmallSizePool scalar_pool_;
|
SmallSizePool scalar_pool_;
|
||||||
cudaMemPool_t cuda_pool_{nullptr};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
CudaAllocator& allocator();
|
CudaAllocator& allocator();
|
||||||
|
|||||||
82
mlx/backend/cuda/cuda_utils.h
Normal file
82
mlx/backend/cuda/cuda_utils.h
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cublasLt.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Throw exception if the cuda API does not succeed.
|
||||||
|
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||||
|
void check_cuda_error(const char* name, cudaError_t err);
|
||||||
|
void check_cuda_error(const char* name, CUresult err);
|
||||||
|
|
||||||
|
// The macro version that prints the command that failed.
|
||||||
|
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||||
|
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
// Base class for RAII managed CUDA resources.
|
||||||
|
template <typename Handle, cudaError_t (*Destroy)(Handle)>
|
||||||
|
class CudaHandle {
|
||||||
|
public:
|
||||||
|
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
|
||||||
|
|
||||||
|
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
|
||||||
|
assert(this != &other);
|
||||||
|
other.handle_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
~CudaHandle() {
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
CudaHandle(const CudaHandle&) = delete;
|
||||||
|
CudaHandle& operator=(const CudaHandle&) = delete;
|
||||||
|
|
||||||
|
CudaHandle& operator=(CudaHandle&& other) {
|
||||||
|
assert(this != &other);
|
||||||
|
reset();
|
||||||
|
std::swap(handle_, other.handle_);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() {
|
||||||
|
if (handle_ != nullptr) {
|
||||||
|
CHECK_CUDA_ERROR(Destroy(handle_));
|
||||||
|
handle_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
operator Handle() const {
|
||||||
|
return handle_;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Handle handle_;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
class Device;
|
||||||
|
}; // namespace cu
|
||||||
|
|
||||||
|
// Wrappers of CUDA resources.
|
||||||
|
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
|
||||||
|
public:
|
||||||
|
using CudaHandle::CudaHandle;
|
||||||
|
explicit CudaGraph(cu::Device& device);
|
||||||
|
void end_capture(cudaStream_t stream);
|
||||||
|
};
|
||||||
|
|
||||||
|
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
|
||||||
|
public:
|
||||||
|
void instantiate(cudaGraph_t graph);
|
||||||
|
};
|
||||||
|
|
||||||
|
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
|
||||||
|
public:
|
||||||
|
explicit CudaStream(cu::Device& device);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -4,101 +4,12 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cublasLt.h>
|
|
||||||
#include <cuda.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
|
#include "mlx/backend/cuda/cuda_utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace cu {
|
|
||||||
class Device;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T* gpu_ptr(array& arr) {
|
|
||||||
return cu::gpu_ptr<T>(arr.buffer());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
const T* gpu_ptr(const array& arr) {
|
|
||||||
return cu::gpu_ptr<T>(arr.buffer());
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Dtype;
|
|
||||||
|
|
||||||
// Throw exception if the cuda API does not succeed.
|
|
||||||
void check_cublas_error(const char* name, cublasStatus_t err);
|
|
||||||
void check_cuda_error(const char* name, cudaError_t err);
|
|
||||||
void check_cuda_error(const char* name, CUresult err);
|
|
||||||
|
|
||||||
// The macro version that prints the command that failed.
|
|
||||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
|
||||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
|
||||||
|
|
||||||
// Convert Dtype to CUDA C++ types.
|
|
||||||
const char* dtype_to_cuda_type(const Dtype& dtype);
|
|
||||||
|
|
||||||
// Base class for RAII managed CUDA resources.
|
|
||||||
template <typename Handle, cudaError_t (*Destroy)(Handle)>
|
|
||||||
class CudaHandle {
|
|
||||||
public:
|
|
||||||
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
|
|
||||||
|
|
||||||
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
|
|
||||||
assert(this != &other);
|
|
||||||
other.handle_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
~CudaHandle() {
|
|
||||||
reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
CudaHandle(const CudaHandle&) = delete;
|
|
||||||
CudaHandle& operator=(const CudaHandle&) = delete;
|
|
||||||
|
|
||||||
CudaHandle& operator=(CudaHandle&& other) {
|
|
||||||
assert(this != &other);
|
|
||||||
reset();
|
|
||||||
std::swap(handle_, other.handle_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
void reset() {
|
|
||||||
if (handle_ != nullptr) {
|
|
||||||
CHECK_CUDA_ERROR(Destroy(handle_));
|
|
||||||
handle_ = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
operator Handle() const {
|
|
||||||
return handle_;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
Handle handle_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Wrappers of CUDA resources.
|
|
||||||
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
|
|
||||||
public:
|
|
||||||
using CudaHandle::CudaHandle;
|
|
||||||
explicit CudaGraph(cu::Device& device);
|
|
||||||
void end_capture(cudaStream_t stream);
|
|
||||||
};
|
|
||||||
|
|
||||||
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
|
|
||||||
public:
|
|
||||||
void instantiate(cudaGraph_t graph);
|
|
||||||
};
|
|
||||||
|
|
||||||
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
|
|
||||||
public:
|
|
||||||
explicit CudaStream(cu::Device& device);
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline uint max_occupancy_block_dim(T kernel) {
|
inline uint max_occupancy_block_dim(T kernel) {
|
||||||
int _, block_dim;
|
int _, block_dim;
|
||||||
@@ -112,4 +23,19 @@ inline uint max_occupancy_block_dim(T kernel) {
|
|||||||
return block_dim;
|
return block_dim;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T* gpu_ptr(array& arr) {
|
||||||
|
return cu::gpu_ptr<T>(arr.buffer());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline const T* gpu_ptr(const array& arr) {
|
||||||
|
return cu::gpu_ptr<T>(arr.buffer());
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Dtype;
|
||||||
|
|
||||||
|
// Convert Dtype to CUDA C++ types.
|
||||||
|
const char* dtype_to_cuda_type(const Dtype& dtype);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -326,8 +326,8 @@ class NCCLGroup : public GroupImpl {
|
|||||||
auto& encoder = cu::get_command_encoder(stream);
|
auto& encoder = cu::get_command_encoder(stream);
|
||||||
|
|
||||||
CHECK_NCCL(ncclAllReduce(
|
CHECK_NCCL(ncclAllReduce(
|
||||||
input.data<T>(),
|
gpu_ptr<T>(input),
|
||||||
output.data<T>(),
|
gpu_ptr<T>(output),
|
||||||
input.size(),
|
input.size(),
|
||||||
dt,
|
dt,
|
||||||
op,
|
op,
|
||||||
|
|||||||
Reference in New Issue
Block a user