mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Use async CUDA managed malloc (stream-ordered pool allocation) with CUDA 13
This commit is contained in:
		| @@ -1,6 +1,7 @@ | |||||||
| // Copyright © 2025 Apple Inc. | // Copyright © 2025 Apple Inc. | ||||||
|  |  | ||||||
| #include "mlx/backend/cuda/allocator.h" | #include "mlx/backend/cuda/allocator.h" | ||||||
|  | #include "mlx/backend/cuda/device.h" | ||||||
| #include "mlx/backend/cuda/utils.h" | #include "mlx/backend/cuda/utils.h" | ||||||
| #include "mlx/utils.h" | #include "mlx/utils.h" | ||||||
|  |  | ||||||
| @@ -93,9 +94,17 @@ CudaAllocator::CudaAllocator() | |||||||
|   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); |   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); | ||||||
|   memory_limit_ = total * 0.95; |   memory_limit_ = total * 0.95; | ||||||
|   max_pool_size_ = memory_limit_; |   max_pool_size_ = memory_limit_; | ||||||
|  | #if CUDART_VERSION >= 13000 | ||||||
|  |   cudaMemLocation loc; | ||||||
|  |   loc.id = 0; | ||||||
|  |   loc.type = cudaMemLocationTypeNone; | ||||||
|  |   cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged); | ||||||
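|  |   // TODO: configure cuda_pool_'s release threshold (presumably via cudaMemPoolSetAttribute with cudaMemPoolAttrReleaseThreshold). | ||||||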
|  |   // uint64_t threshold = UINT64_MAX; | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| Buffer CudaAllocator::malloc(size_t size) { | Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) { | ||||||
|   // Find available buffer from cache. |   // Find available buffer from cache. | ||||||
|   auto orig_size = size; |   auto orig_size = size; | ||||||
|   std::unique_lock lock(mutex_); |   std::unique_lock lock(mutex_); | ||||||
| @@ -123,7 +132,12 @@ Buffer CudaAllocator::malloc(size_t size) { | |||||||
|     lock.unlock(); |     lock.unlock(); | ||||||
|     if (!buf) { |     if (!buf) { | ||||||
|       buf = new CudaBuffer{nullptr, size}; |       buf = new CudaBuffer{nullptr, size}; | ||||||
|       cudaError_t err = cudaMallocManaged(&buf->data, size); |       cudaError_t err; | ||||||
|  |       if (stream != nullptr && cuda_pool_ != nullptr) { | ||||||
|  |         err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream); | ||||||
|  |       } else { | ||||||
|  |         err = cudaMallocManaged(&buf->data, size); | ||||||
|  |       } | ||||||
|       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { |       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { | ||||||
|         throw std::runtime_error(fmt::format( |         throw std::runtime_error(fmt::format( | ||||||
|             "cudaMallocManaged failed: {}.", cudaGetErrorString(err))); |             "cudaMallocManaged failed: {}.", cudaGetErrorString(err))); | ||||||
| @@ -141,6 +155,14 @@ Buffer CudaAllocator::malloc(size_t size) { | |||||||
|   return Buffer{buf}; |   return Buffer{buf}; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) { | ||||||
|  |   return malloc_impl(size, stream); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Buffer CudaAllocator::malloc(size_t size) { | ||||||
|  |   return malloc_impl(size, nullptr); | ||||||
|  | } | ||||||
|  |  | ||||||
| void CudaAllocator::free(Buffer buffer) { | void CudaAllocator::free(Buffer buffer) { | ||||||
|   auto* buf = static_cast<CudaBuffer*>(buffer.ptr()); |   auto* buf = static_cast<CudaBuffer*>(buffer.ptr()); | ||||||
|   if (!buf) { |   if (!buf) { | ||||||
| @@ -220,6 +242,16 @@ CudaAllocator& allocator() { | |||||||
|   return *allocator_; |   return *allocator_; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | Buffer malloc_async(size_t size, cudaStream_t stream) { | ||||||
|  |   auto buffer = allocator().malloc_async(size, stream); | ||||||
|  |   if (size && !buffer.ptr()) { | ||||||
|  |     std::ostringstream msg; | ||||||
|  |     msg << "[malloc_async] Unable to allocate " << size << " bytes."; | ||||||
|  |     throw std::runtime_error(msg.str()); | ||||||
|  |   } | ||||||
|  |   return buffer; | ||||||
|  | } | ||||||
|  |  | ||||||
| } // namespace cu | } // namespace cu | ||||||
|  |  | ||||||
| namespace allocator { | namespace allocator { | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ | |||||||
| #include "mlx/allocator.h" | #include "mlx/allocator.h" | ||||||
| #include "mlx/backend/common/buffer_cache.h" | #include "mlx/backend/common/buffer_cache.h" | ||||||
|  |  | ||||||
|  | #include <cuda_runtime.h> | ||||||
| #include <mutex> | #include <mutex> | ||||||
| #include <set> | #include <set> | ||||||
| #include <utility> | #include <utility> | ||||||
| @@ -45,6 +46,7 @@ class SmallSizePool { | |||||||
| class CudaAllocator : public allocator::Allocator { | class CudaAllocator : public allocator::Allocator { | ||||||
|  public: |  public: | ||||||
|   Buffer malloc(size_t size) override; |   Buffer malloc(size_t size) override; | ||||||
|  |   Buffer malloc_async(size_t size, cudaStream_t stream); | ||||||
|   void free(Buffer buffer) override; |   void free(Buffer buffer) override; | ||||||
|   size_t size(Buffer buffer) const override; |   size_t size(Buffer buffer) const override; | ||||||
|  |  | ||||||
| @@ -58,6 +60,7 @@ class CudaAllocator : public allocator::Allocator { | |||||||
|   void clear_cache(); |   void clear_cache(); | ||||||
|  |  | ||||||
|  private: |  private: | ||||||
|  |   Buffer malloc_impl(size_t size, cudaStream_t stream); | ||||||
|   void cuda_free(CudaBuffer* buf); |   void cuda_free(CudaBuffer* buf); | ||||||
|  |  | ||||||
|   CudaAllocator(); |   CudaAllocator(); | ||||||
| @@ -70,8 +73,11 @@ class CudaAllocator : public allocator::Allocator { | |||||||
|   size_t active_memory_{0}; |   size_t active_memory_{0}; | ||||||
|   size_t peak_memory_{0}; |   size_t peak_memory_{0}; | ||||||
|   SmallSizePool scalar_pool_; |   SmallSizePool scalar_pool_; | ||||||
|  |   cudaMemPool_t cuda_pool_{nullptr}; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| CudaAllocator& allocator(); | CudaAllocator& allocator(); | ||||||
|  |  | ||||||
|  | Buffer malloc_async(size_t size, cudaStream_t stream); | ||||||
|  |  | ||||||
| } // namespace mlx::core::cu | } // namespace mlx::core::cu | ||||||
|   | |||||||
| @@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   if (out.size() == 0) { |   if (out.size() == 0) { | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|   out.set_data(allocator::malloc(out.nbytes())); |  | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(stream()); |   auto& encoder = cu::get_command_encoder(stream()); | ||||||
|  |   out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|   encoder.set_output_array(out); |   encoder.set_output_array(out); | ||||||
|  |  | ||||||
|   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { |   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { | ||||||
|   | |||||||
| @@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   nvtx3::scoped_range r("ArgReduce::eval_gpu"); |   nvtx3::scoped_range r("ArgReduce::eval_gpu"); | ||||||
|   assert(inputs.size() == 1); |   assert(inputs.size() == 1); | ||||||
|   auto& in = inputs[0]; |   auto& in = inputs[0]; | ||||||
|   out.set_data(allocator::malloc(out.nbytes())); |  | ||||||
|   auto& s = stream(); |   auto& s = stream(); | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |   out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|  |  | ||||||
|   // Prepare the shapes, strides and axis arguments. |   // Prepare the shapes, strides and axis arguments. | ||||||
|   Shape shape = remove_index(in.shape(), axis_); |   Shape shape = remove_index(in.shape(), axis_); | ||||||
| @@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   int32_t ndim = shape.size(); |   int32_t ndim = shape.size(); | ||||||
|  |  | ||||||
|   // ArgReduce. |   // ArgReduce. | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   encoder.set_input_array(in); |   encoder.set_input_array(in); | ||||||
|   encoder.set_output_array(out); |   encoder.set_output_array(out); | ||||||
|   dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { |   dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { | ||||||
|   | |||||||
| @@ -87,8 +87,8 @@ void fill_gpu(const array& in, array& out, const Stream& s) { | |||||||
|   if (out.size() == 0) { |   if (out.size() == 0) { | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|   out.set_data(allocator::malloc(out.nbytes())); |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |   out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|   encoder.set_input_array(in); |   encoder.set_input_array(in); | ||||||
|   encoder.set_output_array(out); |   encoder.set_output_array(out); | ||||||
|   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0); |   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0); | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| #pragma once | #pragma once | ||||||
|  |  | ||||||
| #include "mlx/array.h" | #include "mlx/array.h" | ||||||
|  | #include "mlx/backend/cuda/allocator.h" | ||||||
| #include "mlx/backend/cuda/lru_cache.h" | #include "mlx/backend/cuda/lru_cache.h" | ||||||
| #include "mlx/backend/cuda/worker.h" | #include "mlx/backend/cuda/worker.h" | ||||||
| #include "mlx/stream.h" | #include "mlx/stream.h" | ||||||
|   | |||||||
| @@ -370,7 +370,7 @@ void CublasGemm::execute( | |||||||
|     // Ensure workspace is 256-byte aligned |     // Ensure workspace is 256-byte aligned | ||||||
|     int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256; |     int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256; | ||||||
|     array workspace( |     array workspace( | ||||||
|         allocator::malloc(nbytes), |         cu::malloc_async(nbytes, encoder.stream()), | ||||||
|         {static_cast<int>(heuristic_.workspaceSize)}, |         {static_cast<int>(heuristic_.workspaceSize)}, | ||||||
|         int8); |         int8); | ||||||
|     encoder.add_temporary(workspace); |     encoder.add_temporary(workspace); | ||||||
|   | |||||||
| @@ -163,7 +163,7 @@ void CublasGemm::run_batched( | |||||||
|  |  | ||||||
|   // Launch kernel to set device offsets |   // Launch kernel to set device offsets | ||||||
|   auto pointers = array( |   auto pointers = array( | ||||||
|       allocator::malloc(batch_count * sizeof(void*) * 3), |       cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()), | ||||||
|       {batch_count * 3}, |       {batch_count * 3}, | ||||||
|       uint64); |       uint64); | ||||||
|  |  | ||||||
| @@ -251,7 +251,7 @@ void CublasGemm::run_batched( | |||||||
|  |  | ||||||
|   // Launch kernel to set device offsets |   // Launch kernel to set device offsets | ||||||
|   auto pointers = array( |   auto pointers = array( | ||||||
|       allocator::malloc(batch_count * sizeof(uint64_t) * 4), |       cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()), | ||||||
|       {batch_count * 4}, |       {batch_count * 4}, | ||||||
|       uint64); |       uint64); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   assert(inputs.size() > 0); |   assert(inputs.size() > 0); | ||||||
|   const auto& src = inputs[0]; |   const auto& src = inputs[0]; | ||||||
|  |  | ||||||
|   out.set_data(allocator::malloc(out.nbytes())); |   auto& s = stream(); | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |   out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|   if (out.size() == 0) { |   if (out.size() == 0) { | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
| @@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|       dtype_to_string(idx_dtype), |       dtype_to_string(idx_dtype), | ||||||
|       nidx); |       nidx); | ||||||
|  |  | ||||||
|   auto& s = stream(); |  | ||||||
|   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { |   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { | ||||||
|     std::vector<std::string> kernel_names; |     std::vector<std::string> kernel_names; | ||||||
|     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { |     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { | ||||||
| @@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|       idx_ndim, |       idx_ndim, | ||||||
|       large ? "int64_t" : "int32_t"); |       large ? "int64_t" : "int32_t"); | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   for (const auto& in : inputs) { |   for (const auto& in : inputs) { | ||||||
|     encoder.set_input_array(in); |     encoder.set_input_array(in); | ||||||
|   } |   } | ||||||
| @@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   const auto& src = inputs[0]; |   const auto& src = inputs[0]; | ||||||
|   const auto& idx = inputs[1]; |   const auto& idx = inputs[1]; | ||||||
|  |  | ||||||
|   out.set_data(allocator::malloc(out.nbytes())); |   auto& s = stream(); | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |   out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|   if (out.size() == 0) { |   if (out.size() == 0) { | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
| @@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|       dtype_to_string(out.dtype()), |       dtype_to_string(out.dtype()), | ||||||
|       dtype_to_string(idx.dtype())); |       dtype_to_string(idx.dtype())); | ||||||
|  |  | ||||||
|   auto& s = stream(); |  | ||||||
|   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { |   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { | ||||||
|     std::vector<std::string> kernel_names; |     std::vector<std::string> kernel_names; | ||||||
|     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { |     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { | ||||||
| @@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|       idx.flags().row_contiguous, |       idx.flags().row_contiguous, | ||||||
|       large ? "int64_t" : "int32_t"); |       large ? "int64_t" : "int32_t"); | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   for (const auto& in : inputs) { |   for (const auto& in : inputs) { | ||||||
|     encoder.set_input_array(in); |     encoder.set_input_array(in); | ||||||
|   } |   } | ||||||
|   | |||||||
| @@ -230,9 +230,10 @@ void LayerNorm::eval_gpu( | |||||||
|   nvtx3::scoped_range r("LayerNorm::eval_gpu"); |   nvtx3::scoped_range r("LayerNorm::eval_gpu"); | ||||||
|   auto& s = stream(); |   auto& s = stream(); | ||||||
|   auto& out = outputs[0]; |   auto& out = outputs[0]; | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |  | ||||||
|   // Make sure that the last dimension is contiguous. |   // Make sure that the last dimension is contiguous. | ||||||
|   auto set_output = [&s, &out](const array& x) { |   auto set_output = [&s, &out, &encoder](const array& x) { | ||||||
|     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; |     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; | ||||||
|     if (no_copy && x.ndim() > 1) { |     if (no_copy && x.ndim() > 1) { | ||||||
|       auto s = x.strides()[x.ndim() - 2]; |       auto s = x.strides()[x.ndim() - 2]; | ||||||
| @@ -243,7 +244,7 @@ void LayerNorm::eval_gpu( | |||||||
|         out.copy_shared_buffer(x); |         out.copy_shared_buffer(x); | ||||||
|       } else { |       } else { | ||||||
|         out.set_data( |         out.set_data( | ||||||
|             allocator::malloc(x.data_size() * x.itemsize()), |             cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), | ||||||
|             x.data_size(), |             x.data_size(), | ||||||
|             x.strides(), |             x.strides(), | ||||||
|             x.flags()); |             x.flags()); | ||||||
| @@ -265,7 +266,6 @@ void LayerNorm::eval_gpu( | |||||||
|   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; |   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; | ||||||
|   int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; |   int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   encoder.set_input_array(x); |   encoder.set_input_array(x); | ||||||
|   encoder.set_input_array(w); |   encoder.set_input_array(w); | ||||||
|   encoder.set_input_array(b); |   encoder.set_input_array(b); | ||||||
| @@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu( | |||||||
|     gx.copy_shared_buffer(g); |     gx.copy_shared_buffer(g); | ||||||
|     g_in_gx = true; |     g_in_gx = true; | ||||||
|   } else { |   } else { | ||||||
|     gx.set_data(allocator::malloc(gx.nbytes())); |     gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream())); | ||||||
|   } |   } | ||||||
|   if (g_copied && !g_in_gx) { |   if (g_copied && !g_in_gx) { | ||||||
|     encoder.add_temporary(g); |     encoder.add_temporary(g); | ||||||
| @@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu( | |||||||
|       g_in_gw = true; |       g_in_gw = true; | ||||||
|       gw_temp.copy_shared_buffer(g); |       gw_temp.copy_shared_buffer(g); | ||||||
|     } else { |     } else { | ||||||
|       gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); |       gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream())); | ||||||
|       encoder.add_temporary(gw_temp); |       encoder.add_temporary(gw_temp); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   | |||||||
| @@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|  |  | ||||||
|   auto in = ensure_contiguous(inputs[0]); |   auto in = ensure_contiguous(inputs[0]); | ||||||
|   if (in.flags().row_contiguous) { |   if (in.flags().row_contiguous) { | ||||||
|     out.set_data(allocator::malloc(out.nbytes())); |     out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|   } else { |   } else { | ||||||
|     auto n = in.shape(-1); |     auto n = in.shape(-1); | ||||||
|     auto flags = in.flags(); |     auto flags = in.flags(); | ||||||
| @@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|     } |     } | ||||||
|     flags.col_contiguous = col_contig; |     flags.col_contiguous = col_contig; | ||||||
|     out.set_data( |     out.set_data( | ||||||
|         allocator::malloc(in.nbytes() / n), |         cu::malloc_async(in.nbytes() / n, encoder.stream()), | ||||||
|         in.data_size() / n, |         in.data_size() / n, | ||||||
|         std::move(strides), |         std::move(strides), | ||||||
|         flags); |         flags); | ||||||
|   | |||||||
| @@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   out.set_data(allocator::malloc(out.nbytes())); |   out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|  |  | ||||||
|   int M = a_pre.shape(-2); |   int M = a_pre.shape(-2); | ||||||
|   int N = b_pre.shape(-1); |   int N = b_pre.shape(-1); | ||||||
| @@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|  |  | ||||||
|   if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 && |   if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 && | ||||||
|       c.data_size() == out.shape(-1)) { |       c.data_size() == out.shape(-1)) { | ||||||
|     out.set_data(allocator::malloc(out.nbytes())); |     out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|     gemm_and_bias( |     gemm_and_bias( | ||||||
|         encoder, |         encoder, | ||||||
|         M, |         M, | ||||||
| @@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|     auto sty = c.strides()[c.ndim() - 1]; |     auto sty = c.strides()[c.ndim() - 1]; | ||||||
|     if (sty == 1 && stx == c.shape(-1)) { |     if (sty == 1 && stx == c.shape(-1)) { | ||||||
|       ldc = stx; |       ldc = stx; | ||||||
|       out.set_data(allocator::malloc(out.nbytes())); |       out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|     } else if (sty == 1 && stx == 0) { |     } else if (sty == 1 && stx == 0) { | ||||||
|       ldc = 0; |       ldc = 0; | ||||||
|       out.set_data(allocator::malloc(out.nbytes())); |       out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|     } else { |     } else { | ||||||
|       // Copy C into out and set C to out |       // Copy C into out and set C to out | ||||||
|       ldc = c.shape(-1); |       ldc = c.shape(-1); | ||||||
|   | |||||||
| @@ -176,9 +176,10 @@ void RMSNorm::eval_gpu( | |||||||
|   nvtx3::scoped_range r("RMSNorm::eval_gpu"); |   nvtx3::scoped_range r("RMSNorm::eval_gpu"); | ||||||
|   auto& s = stream(); |   auto& s = stream(); | ||||||
|   auto& out = outputs[0]; |   auto& out = outputs[0]; | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |  | ||||||
|   // Make sure that the last dimension is contiguous. |   // Make sure that the last dimension is contiguous. | ||||||
|   auto set_output = [&s, &out](const array& x) { |   auto set_output = [&s, &out, &encoder](const array& x) { | ||||||
|     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; |     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; | ||||||
|     if (no_copy && x.ndim() > 1) { |     if (no_copy && x.ndim() > 1) { | ||||||
|       auto s = x.strides()[x.ndim() - 2]; |       auto s = x.strides()[x.ndim() - 2]; | ||||||
| @@ -189,7 +190,7 @@ void RMSNorm::eval_gpu( | |||||||
|         out.copy_shared_buffer(x); |         out.copy_shared_buffer(x); | ||||||
|       } else { |       } else { | ||||||
|         out.set_data( |         out.set_data( | ||||||
|             allocator::malloc(x.data_size() * x.itemsize()), |             cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), | ||||||
|             x.data_size(), |             x.data_size(), | ||||||
|             x.strides(), |             x.strides(), | ||||||
|             x.flags()); |             x.flags()); | ||||||
| @@ -209,7 +210,6 @@ void RMSNorm::eval_gpu( | |||||||
|   int32_t n_rows = x.data_size() / axis_size; |   int32_t n_rows = x.data_size() / axis_size; | ||||||
|   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; |   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   encoder.set_input_array(x); |   encoder.set_input_array(x); | ||||||
|   encoder.set_input_array(w); |   encoder.set_input_array(w); | ||||||
|   encoder.set_output_array(out); |   encoder.set_output_array(out); | ||||||
| @@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu( | |||||||
|     gx.copy_shared_buffer(g); |     gx.copy_shared_buffer(g); | ||||||
|     g_in_gx = true; |     g_in_gx = true; | ||||||
|   } else { |   } else { | ||||||
|     gx.set_data(allocator::malloc(gx.nbytes())); |     gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream())); | ||||||
|   } |   } | ||||||
|   if (g_copied && !g_in_gx) { |   if (g_copied && !g_in_gx) { | ||||||
|     encoder.add_temporary(g); |     encoder.add_temporary(g); | ||||||
| @@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu( | |||||||
|     if (!g_in_gx && donate_g) { |     if (!g_in_gx && donate_g) { | ||||||
|       gw_temp.copy_shared_buffer(g); |       gw_temp.copy_shared_buffer(g); | ||||||
|     } else { |     } else { | ||||||
|       gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); |       gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream())); | ||||||
|       encoder.add_temporary(gw_temp); |       encoder.add_temporary(gw_temp); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   | |||||||
| @@ -250,6 +250,7 @@ void RoPE::eval_gpu( | |||||||
|   nvtx3::scoped_range r("RoPE::eval_gpu"); |   nvtx3::scoped_range r("RoPE::eval_gpu"); | ||||||
|  |  | ||||||
|   auto& s = stream(); |   auto& s = stream(); | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|   auto& in = inputs[0]; |   auto& in = inputs[0]; | ||||||
|   auto& offset = inputs[1]; |   auto& offset = inputs[1]; | ||||||
|   auto& out = outputs[0]; |   auto& out = outputs[0]; | ||||||
| @@ -291,14 +292,14 @@ void RoPE::eval_gpu( | |||||||
|       donated = true; |       donated = true; | ||||||
|       out.copy_shared_buffer(in); |       out.copy_shared_buffer(in); | ||||||
|     } else { |     } else { | ||||||
|       out.set_data(allocator::malloc(out.nbytes())); |       out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|     } |     } | ||||||
|     strides[0] = mat_size; |     strides[0] = mat_size; | ||||||
|     strides[1] = in.strides()[ndim - 2]; |     strides[1] = in.strides()[ndim - 2]; | ||||||
|     strides[2] = in.strides()[ndim - 1]; |     strides[2] = in.strides()[ndim - 1]; | ||||||
|   } else if (dispatch_ndim == 3) { |   } else if (dispatch_ndim == 3) { | ||||||
|     // Handle non-contiguous 3D inputs |     // Handle non-contiguous 3D inputs | ||||||
|     out.set_data(allocator::malloc(out.nbytes())); |     out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|     strides[0] = in.strides()[ndim - 3]; |     strides[0] = in.strides()[ndim - 3]; | ||||||
|     strides[1] = in.strides()[ndim - 2]; |     strides[1] = in.strides()[ndim - 2]; | ||||||
|     strides[2] = in.strides()[ndim - 1]; |     strides[2] = in.strides()[ndim - 1]; | ||||||
| @@ -319,7 +320,6 @@ void RoPE::eval_gpu( | |||||||
|   bool single = in.flags().row_contiguous && B == 1 && T == 1; |   bool single = in.flags().row_contiguous && B == 1 && T == 1; | ||||||
|   bool with_freqs = inputs.size() == 3; |   bool with_freqs = inputs.size() == 3; | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   encoder.set_input_array(donated ? out : in); |   encoder.set_input_array(donated ? out : in); | ||||||
|   encoder.set_input_array(offset); |   encoder.set_input_array(offset); | ||||||
|   if (with_freqs) { |   if (with_freqs) { | ||||||
|   | |||||||
| @@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback( | |||||||
|   array sums(intermediate_shape, float32, nullptr, {}); |   array sums(intermediate_shape, float32, nullptr, {}); | ||||||
|   array maxs(std::move(intermediate_shape), float32, nullptr, {}); |   array maxs(std::move(intermediate_shape), float32, nullptr, {}); | ||||||
|  |  | ||||||
|   intermediate.set_data(allocator::malloc(intermediate.nbytes())); |   intermediate.set_data( | ||||||
|   sums.set_data(allocator::malloc(sums.nbytes())); |       cu::malloc_async(intermediate.nbytes(), encoder.stream())); | ||||||
|   maxs.set_data(allocator::malloc(maxs.nbytes())); |   sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream())); | ||||||
|  |   maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream())); | ||||||
|  |  | ||||||
|   encoder.add_temporary(intermediate); |   encoder.add_temporary(intermediate); | ||||||
|   encoder.add_temporary(sums); |   encoder.add_temporary(sums); | ||||||
| @@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu( | |||||||
|       }; |       }; | ||||||
|  |  | ||||||
|       o.set_data( |       o.set_data( | ||||||
|           allocator::malloc(o.nbytes()), |           cu::malloc_async(o.nbytes(), encoder.stream()), | ||||||
|           o.size(), |           o.size(), | ||||||
|           {str_oB, str_oH, str_oL, str_oD}, |           {str_oB, str_oH, str_oL, str_oD}, | ||||||
|           flags); |           flags); | ||||||
|   | |||||||
| @@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   assert(inputs.size() == 1); |   assert(inputs.size() == 1); | ||||||
|   auto in = inputs[0]; |   auto in = inputs[0]; | ||||||
|   auto& s = stream(); |   auto& s = stream(); | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |  | ||||||
|   if (in.flags().contiguous && in.strides()[axis_] != 0) { |   if (in.flags().contiguous && in.strides()[axis_] != 0) { | ||||||
|     if (in.is_donatable() && in.itemsize() == out.itemsize()) { |     if (in.is_donatable() && in.itemsize() == out.itemsize()) { | ||||||
|       out.copy_shared_buffer(in); |       out.copy_shared_buffer(in); | ||||||
|     } else { |     } else { | ||||||
|       out.set_data( |       out.set_data( | ||||||
|           allocator::malloc(in.data_size() * out.itemsize()), |           cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()), | ||||||
|           in.data_size(), |           in.data_size(), | ||||||
|           in.strides(), |           in.strides(), | ||||||
|           in.flags()); |           in.flags()); | ||||||
| @@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   int32_t axis_size = in.shape(axis_); |   int32_t axis_size = in.shape(axis_); | ||||||
|   bool contiguous = in.strides()[axis_] == 1; |   bool contiguous = in.strides()[axis_] == 1; | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   encoder.set_input_array(in); |   encoder.set_input_array(in); | ||||||
|   encoder.set_output_array(out); |   encoder.set_output_array(out); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -23,14 +23,15 @@ void concatenate_gpu( | |||||||
|   } |   } | ||||||
|   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); |   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); | ||||||
|  |  | ||||||
|   out.set_data(allocator::malloc(out.nbytes())); |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |   out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); | ||||||
|  |  | ||||||
|   auto strides = out.strides(); |   auto strides = out.strides(); | ||||||
|   auto flags = out.flags(); |   auto flags = out.flags(); | ||||||
|   flags.row_contiguous = false; |   flags.row_contiguous = false; | ||||||
|   flags.col_contiguous = false; |   flags.col_contiguous = false; | ||||||
|   flags.contiguous = false; |   flags.contiguous = false; | ||||||
|   auto concurrent = cu::get_command_encoder(s).concurrent_context(); |   auto concurrent = encoder.concurrent_context(); | ||||||
|   for (int i = 0; i < inputs.size(); i++) { |   for (int i = 0; i < inputs.size(); i++) { | ||||||
|     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); |     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); | ||||||
|     size_t data_offset = strides[axis] * sizes[i]; |     size_t data_offset = strides[axis] * sizes[i]; | ||||||
| @@ -80,6 +81,7 @@ array compute_dynamic_offset( | |||||||
|     return std::make_tuple(false, std::move(source), std::vector{kernel_name}); |     return std::make_tuple(false, std::move(source), std::vector{kernel_name}); | ||||||
|   }); |   }); | ||||||
|  |  | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|   // Prepare output. |   // Prepare output. | ||||||
|   array offset({1}, int64, nullptr, {}); |   array offset({1}, int64, nullptr, {}); | ||||||
|   bool donate = indices.is_donatable() && |   bool donate = indices.is_donatable() && | ||||||
| @@ -87,10 +89,9 @@ array compute_dynamic_offset( | |||||||
|   if (donate) { |   if (donate) { | ||||||
|     offset.copy_shared_buffer(indices); |     offset.copy_shared_buffer(indices); | ||||||
|   } else { |   } else { | ||||||
|     offset.set_data(allocator::malloc(offset.itemsize())); |     offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream())); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   encoder.add_temporary(offset); |   encoder.add_temporary(offset); | ||||||
|   encoder.set_input_array(indices); |   encoder.set_input_array(indices); | ||||||
|   encoder.set_output_array(offset); |   encoder.set_output_array(offset); | ||||||
|   | |||||||
| @@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   nvtx3::scoped_range r("Softmax::eval_gpu"); |   nvtx3::scoped_range r("Softmax::eval_gpu"); | ||||||
|   assert(inputs.size() == 1); |   assert(inputs.size() == 1); | ||||||
|   auto& s = stream(); |   auto& s = stream(); | ||||||
|  |   auto& encoder = cu::get_command_encoder(s); | ||||||
|  |  | ||||||
|   // Make sure that the last dimension is contiguous. |   // Make sure that the last dimension is contiguous. | ||||||
|   auto set_output = [&s, &out](const array& x) { |   auto set_output = [&s, &out, &encoder](const array& x) { | ||||||
|     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { |     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { | ||||||
|       if (x.is_donatable()) { |       if (x.is_donatable()) { | ||||||
|         out.copy_shared_buffer(x); |         out.copy_shared_buffer(x); | ||||||
|       } else { |       } else { | ||||||
|         out.set_data( |         out.set_data( | ||||||
|             allocator::malloc(x.data_size() * x.itemsize()), |             cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), | ||||||
|             x.data_size(), |             x.data_size(), | ||||||
|             x.strides(), |             x.strides(), | ||||||
|             x.flags()); |             x.flags()); | ||||||
| @@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   int axis_size = in.shape().back(); |   int axis_size = in.shape().back(); | ||||||
|   int n_rows = in.data_size() / axis_size; |   int n_rows = in.data_size() / axis_size; | ||||||
|  |  | ||||||
|   auto& encoder = cu::get_command_encoder(s); |  | ||||||
|   encoder.set_input_array(in); |   encoder.set_input_array(in); | ||||||
|   encoder.set_output_array(out); |   encoder.set_output_array(out); | ||||||
|   dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { |   dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { | ||||||
|   | |||||||
| @@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { | |||||||
|     array trans = swapaxes_in_eval(in, axis, last_dim); |     array trans = swapaxes_in_eval(in, axis, last_dim); | ||||||
|     in = contiguous_copy_gpu(trans, s); |     in = contiguous_copy_gpu(trans, s); | ||||||
|     encoder.add_temporary(in); |     encoder.add_temporary(in); | ||||||
|     out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); |     out = array( | ||||||
|  |         cu::malloc_async(out.nbytes(), encoder.stream()), | ||||||
|  |         in.shape(), | ||||||
|  |         out.dtype()); | ||||||
|     encoder.add_temporary(out); |     encoder.add_temporary(out); | ||||||
|   } else { |   } else { | ||||||
|     out.set_data( |     out.set_data( | ||||||
|         allocator::malloc(in.data_size() * out.itemsize()), |         cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()), | ||||||
|         in.data_size(), |         in.data_size(), | ||||||
|         in.strides(), |         in.strides(), | ||||||
|         in.flags()); |         in.flags()); | ||||||
| @@ -70,12 +73,18 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { | |||||||
|           thrust::make_counting_iterator(0), OffsetTransform{nsort}); |           thrust::make_counting_iterator(0), OffsetTransform{nsort}); | ||||||
|       if (argsort) { |       if (argsort) { | ||||||
|         // Indices in the sorted dimension. |         // Indices in the sorted dimension. | ||||||
|         array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); |         array indices( | ||||||
|  |             cu::malloc_async(out.nbytes(), encoder.stream()), | ||||||
|  |             in.shape(), | ||||||
|  |             out.dtype()); | ||||||
|         encoder.add_temporary(indices); |         encoder.add_temporary(indices); | ||||||
|  |  | ||||||
|         // In argsort though we don't need the result of sorted values, the |         // In argsort though we don't need the result of sorted values, the | ||||||
|         // API requires us to provide an array to store it. |         // API requires us to provide an array to store it. | ||||||
|         array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); |         array discard( | ||||||
|  |             cu::malloc_async(in.nbytes(), encoder.stream()), | ||||||
|  |             in.shape(), | ||||||
|  |             in.dtype()); | ||||||
|         encoder.add_temporary(discard); |         encoder.add_temporary(discard); | ||||||
|  |  | ||||||
|         size_t size; |         size_t size; | ||||||
| @@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { | |||||||
|             sizeof(Type) * 8, |             sizeof(Type) * 8, | ||||||
|             stream)); |             stream)); | ||||||
|  |  | ||||||
|         array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8); |         array temp( | ||||||
|  |             cu::malloc_async(size, encoder.stream()), | ||||||
|  |             {static_cast<int>(size)}, | ||||||
|  |             uint8); | ||||||
|         encoder.add_temporary(temp); |         encoder.add_temporary(temp); | ||||||
|  |  | ||||||
|         // Start capturing after allocations |         // Start capturing after allocations | ||||||
| @@ -135,7 +147,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { | |||||||
|             sizeof(Type) * 8, |             sizeof(Type) * 8, | ||||||
|             stream)); |             stream)); | ||||||
|  |  | ||||||
|         array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8); |         array temp( | ||||||
|  |             cu::malloc_async(size, encoder.stream()), | ||||||
|  |             {static_cast<int>(size)}, | ||||||
|  |             uint8); | ||||||
|         encoder.add_temporary(temp); |         encoder.add_temporary(temp); | ||||||
|  |  | ||||||
|         // Start capturing after allocations |         // Start capturing after allocations | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun