Mirror of https://github.com/ml-explore/mlx.git

	[CUDA] speedup handling scalars (#2389)
* speedup scalars in cuda
* comment

@@ -17,6 +17,52 @@ namespace cu {
 
 constexpr int page_size = 16384;
 
+// Any allocations smaller than this will try to use the small pool
+constexpr int small_block_size = 8;
+
+// The small pool size in bytes. This should be a multiple of the host page
+// size and small_block_size.
+constexpr int small_pool_size = 4 * page_size;
+
+SmallSizePool::SmallSizePool() {
+  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
+  end_ = reinterpret_cast<void*>(
+      reinterpret_cast<char*>(buffer_) + small_pool_size);
+  next_free_ = reinterpret_cast<Block*>(buffer_);
+
+  auto num_blocks = small_pool_size / small_block_size;
+  auto curr = next_free_;
+  for (size_t i = 0; i < num_blocks - 1; ++i) {
+    curr->next = reinterpret_cast<Block*>(
+        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+    curr = curr->next;
+  }
+  curr->next = nullptr;
+}
+
+SmallSizePool::~SmallSizePool() {
+  CHECK_CUDA_ERROR(cudaFree(buffer_));
+}
+
+void* SmallSizePool::malloc() {
+  if (next_free_ == nullptr) {
+    return nullptr;
+  }
+  Block* b = next_free_;
+  next_free_ = next_free_->next;
+  return static_cast<void*>(b);
+}
+
+void SmallSizePool::free(void* p) {
+  auto b = static_cast<Block*>(p);
+  b->next = next_free_;
+  next_free_ = b;
+}
+
+bool SmallSizePool::in_pool(void* p) {
+  return (p >= buffer_) && (p < end_);
+}
+
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
           page_size,
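
The pool added above is a fixed-size free-list allocator: one flat cudaMallocManaged buffer is carved into 8-byte blocks, and every free block stores a pointer to the next free block, so SmallSizePool::malloc() and SmallSizePool::free() are O(1) pointer swaps with no CUDA runtime calls. A minimal, host-only sketch of the same technique follows; the names (FixedPool, kBlockSize, kPoolSize) and the use of operator new[] in place of cudaMallocManaged are illustrative assumptions, not the MLX implementation.

// Sketch of a fixed-size free-list pool, analogous to SmallSizePool above.
// Host memory only; the real pool backs the buffer with cudaMallocManaged.
#include <cassert>
#include <cstddef>

class FixedPool {
  struct Block {
    Block* next;
  };

  static constexpr size_t kBlockSize = 8; // smallest bucket, as in the diff
  static constexpr size_t kPoolSize = 4 * 16384; // a multiple of kBlockSize

  char* buffer_;
  Block* next_free_;

 public:
  FixedPool() : buffer_(new char[kPoolSize]) {
    // Thread every block onto a singly linked free list.
    next_free_ = reinterpret_cast<Block*>(buffer_);
    size_t num_blocks = kPoolSize / kBlockSize;
    Block* curr = next_free_;
    for (size_t i = 0; i + 1 < num_blocks; ++i) {
      curr->next = reinterpret_cast<Block*>(buffer_ + (i + 1) * kBlockSize);
      curr = curr->next;
    }
    curr->next = nullptr;
  }
  ~FixedPool() {
    delete[] buffer_;
  }

  FixedPool(const FixedPool&) = delete;
  FixedPool& operator=(const FixedPool&) = delete;

  // O(1): pop the head of the free list; nullptr means the pool is exhausted.
  void* malloc() {
    if (next_free_ == nullptr) {
      return nullptr;
    }
    Block* b = next_free_;
    next_free_ = next_free_->next;
    return b;
  }

  // O(1): push the block back onto the head of the free list.
  void free(void* p) {
    Block* b = static_cast<Block*>(p);
    b->next = next_free_;
    next_free_ = b;
  }

  // Membership test by address range, mirroring SmallSizePool::in_pool.
  bool in_pool(const void* p) const {
    auto* c = static_cast<const char*>(p);
    return c >= buffer_ && c < buffer_ + kPoolSize;
  }
};

int main() {
  FixedPool pool;
  void* a = pool.malloc();
  void* b = pool.malloc();
  assert(a != nullptr && b != nullptr && a != b);
  assert(pool.in_pool(a) && pool.in_pool(b));
  pool.free(b);
  pool.free(a);
  return 0;
}

Because a free block has to hold the next-pointer itself, the block size cannot usefully drop below sizeof(void*); that is why 8 bytes is the natural bucket for scalars on a 64-bit build, and why the diff routes only allocations of at most small_block_size bytes through the pool.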
@@ -36,7 +82,9 @@ Buffer CudaAllocator::malloc(size_t size) {
   // Find available buffer from cache.
   auto orig_size = size;
   std::unique_lock lock(mutex_);
-  if (size < page_size) {
+  if (size <= small_block_size) {
+    size = 8;
+  } else if (size < page_size) {
     size = next_power_of_2(size);
   } else {
     size = page_size * ((size + page_size - 1) / page_size);
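
For context, the size classes above now have three tiers: requests of at most small_block_size bytes fall into the fixed 8-byte scalar bucket, requests under a page are rounded up to the next power of two for the buffer cache, and anything larger is rounded up to whole 16384-byte pages. A standalone sketch of that policy is below; the next_power_of_2 helper here is a hypothetical stand-in for the one the MLX code calls, and the constants simply mirror the diff.

// Sketch of the size-bucketing policy used in CudaAllocator::malloc above.
#include <cstddef>
#include <cstdio>

constexpr size_t small_block_size = 8;
constexpr size_t page_size = 16384;

// Hypothetical helper: round n up to the nearest power of two.
size_t next_power_of_2(size_t n) {
  size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

size_t bucket(size_t size) {
  if (size <= small_block_size) {
    return small_block_size; // served by the scalar pool
  } else if (size < page_size) {
    return next_power_of_2(size); // power-of-two classes for the buffer cache
  } else {
    return page_size * ((size + page_size - 1) / page_size); // whole pages
  }
}

int main() {
  // Prints: 8 128 32768
  std::printf("%zu %zu %zu\n", bucket(4), bucket(100), bucket(20000));
  return 0;
}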
@@ -53,11 +101,19 @@ Buffer CudaAllocator::malloc(size_t size) {
 
     lock.unlock();
     buf = new CudaBuffer{nullptr, size};
-    cudaError_t err = cudaMallocManaged(&buf->data, size);
-    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-      throw std::runtime_error(fmt::format(
-          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
-    }
+
+    // Try the scalar pool first
+    if (size <= small_block_size) {
+      buf->data = scalar_pool_.malloc();
+    }
+    if (!buf->data) {
+      cudaError_t err = cudaMallocManaged(&buf->data, size);
+      if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+        throw std::runtime_error(fmt::format(
+            "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+      }
+    }
+
     lock.lock();
   }
   active_memory_ += size;
@@ -116,8 +172,12 @@ void CudaAllocator::cuda_free(void* buf) {
       return;
     }
   }
-  cudaFree(buf);
+  if (scalar_pool_.in_pool(buf)) {
+    scalar_pool_.free(buf);
+  } else {
+    cudaFree(buf);
+  }
 }
 
 size_t CudaAllocator::get_active_memory() const {
   return active_memory_;
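
The point of the pool is to avoid a cudaMallocManaged/cudaFree round trip for every few-byte scalar. The hypothetical micro-benchmark below (not part of the commit) contrasts one managed allocation per 8-byte value with handing out slices of a single managed buffer; it needs a CUDA toolkit and device to run, error checking is omitted for brevity, and the iteration count is arbitrary.

// Illustrates the allocator-call overhead the scalar pool sidesteps.
#include <chrono>
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

int main() {
  constexpr int n = 10000;
  std::vector<void*> ptrs(n, nullptr);

  // One managed allocation per scalar.
  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < n; ++i) {
    cudaMallocManaged(&ptrs[i], 8);
  }
  auto t1 = std::chrono::steady_clock::now();
  for (int i = 0; i < n; ++i) {
    cudaFree(ptrs[i]);
  }

  // One managed buffer, handed out 8 bytes at a time (what a pool amortizes).
  void* buf = nullptr;
  cudaMallocManaged(&buf, 8 * n);
  auto t2 = std::chrono::steady_clock::now();
  for (int i = 0; i < n; ++i) {
    ptrs[i] = static_cast<char*>(buf) + 8 * i; // pointer bump, no CUDA call
  }
  auto t3 = std::chrono::steady_clock::now();
  cudaFree(buf);

  using ms = std::chrono::duration<double, std::milli>;
  std::printf(
      "per-scalar cudaMallocManaged: %.2f ms, pooled slices: %.2f ms\n",
      ms(t1 - t0).count(),
      ms(t3 - t2).count());
  return 0;
}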

@@ -22,6 +22,28 @@ struct CudaBuffer {
   size_t size;
 };
 
+class SmallSizePool {
+ private:
+  struct Block {
+    Block* next;
+  };
+
+  void* buffer_{nullptr};
+  Block* next_free_{nullptr};
+  void* end_{nullptr};
+
+ public:
+  SmallSizePool();
+  ~SmallSizePool();
+
+  SmallSizePool(const SmallSizePool&) = delete;
+  SmallSizePool& operator=(const SmallSizePool&) = delete;
+
+  void* malloc();
+  void free(void* p);
+  bool in_pool(void* p);
+};
+
 class CudaAllocator : public allocator::Allocator {
  public:
   Buffer malloc(size_t size) override;
@@ -60,6 +82,7 @@ class CudaAllocator : public allocator::Allocator {
   BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
+  SmallSizePool scalar_pool_;
 };
 
 CudaAllocator& allocator();
Author: Awni Hannun