mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)

Commit: Use async cuda malloc managed with cuda 13
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/utils.h"
 
@@ -93,9 +94,17 @@ CudaAllocator::CudaAllocator()
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
+#if CUDART_VERSION >= 13000
+  cudaMemLocation loc;
+  loc.id = 0;
+  loc.type = cudaMemLocationTypeNone;
+  cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged);
+  // TODO set that.
+  // uint64_t threshold = UINT64_MAX;
+#endif
 }
 
-Buffer CudaAllocator::malloc(size_t size) {
+Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
   // Find available buffer from cache.
   auto orig_size = size;
   std::unique_lock lock(mutex_);
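
The TODO in the constructor leaves the pool's release threshold at CUDA's default of zero, which returns freed pool memory to the OS at the next synchronization point. A minimal sketch of what the commented-out threshold could become (an assumption based on that comment, not part of this commit), using the stock cudaMemPoolSetAttribute API:

    #if CUDART_VERSION >= 13000
      // Keep freed blocks cached in the pool instead of releasing them back
      // to the OS at stream synchronization points.
      uint64_t threshold = UINT64_MAX;
      CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
          cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold));
    #endif
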
@@ -123,7 +132,12 @@ Buffer CudaAllocator::malloc(size_t size) {
     lock.unlock();
     if (!buf) {
       buf = new CudaBuffer{nullptr, size};
-      cudaError_t err = cudaMallocManaged(&buf->data, size);
+      cudaError_t err;
+      if (stream != nullptr && cuda_pool_ != nullptr) {
+        err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
+      } else {
+        err = cudaMallocManaged(&buf->data, size);
+      }
       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
         throw std::runtime_error(fmt::format(
             "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
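
Two details in this hunk are worth noting. First, cudaErrorMemoryAllocation is deliberately not thrown, presumably so the surrounding allocator logic can treat out-of-memory as recoverable (for example by evicting cached buffers and retrying). Second, the message still names cudaMallocManaged even when the pool path failed; a more precise message would name the API that was actually called, e.g. (a suggestion, not in the commit):

    throw std::runtime_error(fmt::format(
        "{} failed: {}.",
        (stream != nullptr && cuda_pool_ != nullptr) ? "cudaMallocFromPoolAsync"
                                                     : "cudaMallocManaged",
        cudaGetErrorString(err)));
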
@@ -141,6 +155,14 @@ Buffer CudaAllocator::malloc(size_t size) {
   return Buffer{buf};
 }
 
+Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
+  return malloc_impl(size, stream);
+}
+
+Buffer CudaAllocator::malloc(size_t size) {
+  return malloc_impl(size, nullptr);
+}
+
 void CudaAllocator::free(Buffer buffer) {
   auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
   if (!buf) {
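
The two entry points split cleanly: the synchronous malloc override passes a null stream and therefore always takes the cudaMallocManaged path, while malloc_async opts into the CUDA 13 pool when one was obtained. From a caller's perspective (illustrative only):

    auto a = cu::allocator().malloc(nbytes);                          // managed, stream-agnostic
    auto b = cu::allocator().malloc_async(nbytes, encoder.stream());  // pool-backed on CUDA 13+
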
@@ -220,6 +242,16 @@ CudaAllocator& allocator() {
   return *allocator_;
 }
 
+Buffer malloc_async(size_t size, cudaStream_t stream) {
+  auto buffer = allocator().malloc_async(size, stream);
+  if (size && !buffer.ptr()) {
+    std::ostringstream msg;
+    msg << "[malloc_async] Unable to allocate " << size << " bytes.";
+    throw std::runtime_error(msg.str());
+  }
+  return buffer;
+}
+
 } // namespace cu
 
 namespace allocator {
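
For context, the point of allocating on a stream: stream-ordered allocations and frees are sequenced with the kernels on that stream, so no device-wide synchronization is needed between allocating, using, and releasing a buffer. A standalone CUDA sketch (pool, stream, my_kernel, and n are placeholders, not code from this repo):

    void* p = nullptr;
    // The allocation is queued on `stream`; it completes before later stream work.
    CHECK_CUDA_ERROR(cudaMallocFromPoolAsync(&p, n * sizeof(float), pool, stream));
    my_kernel<<<(n + 255) / 256, 256, 0, stream>>>(static_cast<float*>(p), n);
    // The free is also stream-ordered; memory returns to the pool for reuse.
    CHECK_CUDA_ERROR(cudaFreeAsync(p, stream));
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));  // only needed for host handoff
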

@@ -5,6 +5,7 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/buffer_cache.h"
 
+#include <cuda_runtime.h>
 #include <mutex>
 #include <set>
 #include <utility>
@@ -45,6 +46,7 @@ class SmallSizePool {
 class CudaAllocator : public allocator::Allocator {
  public:
   Buffer malloc(size_t size) override;
+  Buffer malloc_async(size_t size, cudaStream_t stream);
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
 
@@ -58,6 +60,7 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
+  Buffer malloc_impl(size_t size, cudaStream_t stream);
   void cuda_free(CudaBuffer* buf);
 
   CudaAllocator();
@@ -70,8 +73,11 @@ class CudaAllocator : public allocator::Allocator {
   size_t active_memory_{0};
   size_t peak_memory_{0};
   SmallSizePool scalar_pool_;
+  cudaMemPool_t cuda_pool_{nullptr};
 };
 
 CudaAllocator& allocator();
 
+Buffer malloc_async(size_t size, cudaStream_t stream);
+
 } // namespace mlx::core::cu

@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc(out.nbytes()));
-
   auto& encoder = cu::get_command_encoder(stream());
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   encoder.set_output_array(out);
 
   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
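
This first call-site hunk establishes the mechanical pattern applied everywhere below: obtain the command encoder before allocating the output, then swap allocator::malloc(n) for cu::malloc_async(n, encoder.stream()) so the allocation is ordered on the stream that will consume the buffer. Schematically:

    // before: allocation precedes (and is independent of) the encoder
    out.set_data(allocator::malloc(out.nbytes()));
    auto& encoder = cu::get_command_encoder(s);

    // after: encoder first, allocation ordered on its stream
    auto& encoder = cu::get_command_encoder(s);
    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
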

@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("ArgReduce::eval_gpu");
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 
   // Prepare the shapes, strides and axis arguments.
   Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   int32_t ndim = shape.size();
 
   // ArgReduce.
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {

@@ -87,8 +87,8 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/cuda/worker.h"
 #include "mlx/stream.h"

@@ -370,7 +370,7 @@ void CublasGemm::execute(
     // Ensure workspace is 256-byte aligned
     int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
     array workspace(
-        allocator::malloc(nbytes),
+        cu::malloc_async(nbytes, encoder.stream()),
         {static_cast<int>(heuristic_.workspaceSize)},
         int8);
     encoder.add_temporary(workspace);
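
As with the workspace above, every scratch allocation in this commit is paired with encoder.add_temporary, which (per this codebase's idiom) ties the array's lifetime to the encoded work rather than to the enclosing C++ scope; with stream-ordered allocation, that pairing is what keeps the buffer valid until the kernels that consume it have run. The recurring shape of the idiom (identifiers as in the surrounding hunks):

    array scratch(
        cu::malloc_async(nbytes, encoder.stream()),  // allocated on the stream
        {static_cast<int>(nbytes)},
        int8);
    encoder.add_temporary(scratch);  // lifetime follows the encoder, not scope
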

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
 
   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(void*) * 3),
+      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
       {batch_count * 3},
       uint64);
 
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
 
   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
+      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
       {batch_count * 4},
       uint64);
 

@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() > 0);
   const auto& src = inputs[0];
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       dtype_to_string(idx_dtype),
       nidx);
 
-  auto& s = stream();
   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
     std::vector<std::string> kernel_names;
     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       idx_ndim,
       large ? "int64_t" : "int32_t");
 
-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   const auto& src = inputs[0];
   const auto& idx = inputs[1];
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       dtype_to_string(out.dtype()),
       dtype_to_string(idx.dtype()));
 
-  auto& s = stream();
   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
     std::vector<std::string> kernel_names;
     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       idx.flags().row_contiguous,
       large ? "int64_t" : "int32_t");
 
-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }

@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
   nvtx3::scoped_range r("LayerNorm::eval_gpu");
   auto& s = stream();
   auto& out = outputs[0];
+  auto& encoder = cu::get_command_encoder(s);
 
   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
     if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
         out.copy_shared_buffer(x);
       } else {
         out.set_data(
-            allocator::malloc(x.data_size() * x.itemsize()),
+            cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
             x.data_size(),
             x.strides(),
             x.flags());
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
   int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(x);
   encoder.set_input_array(w);
   encoder.set_input_array(b);
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc(gx.nbytes()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
       g_in_gw = true;
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
       encoder.add_temporary(gw_temp);
     }
   }

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   auto in = ensure_contiguous(inputs[0]);
   if (in.flags().row_contiguous) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   } else {
     auto n = in.shape(-1);
     auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
     flags.col_contiguous = col_contig;
     out.set_data(
-        allocator::malloc(in.nbytes() / n),
+        cu::malloc_async(in.nbytes() / n, encoder.stream()),
         in.data_size() / n,
         std::move(strides),
         flags);

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 
   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
       c.data_size() == out.shape(-1)) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     gemm_and_bias(
         encoder,
         M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto sty = c.strides()[c.ndim() - 1];
     if (sty == 1 && stx == c.shape(-1)) {
       ldc = stx;
-      out.set_data(allocator::malloc(out.nbytes()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     } else if (sty == 1 && stx == 0) {
       ldc = 0;
-      out.set_data(allocator::malloc(out.nbytes()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     } else {
       // Copy C into out and set C to out
       ldc = c.shape(-1);

@@ -176,9 +176,10 @@ void RMSNorm::eval_gpu(
   nvtx3::scoped_range r("RMSNorm::eval_gpu");
   auto& s = stream();
   auto& out = outputs[0];
+  auto& encoder = cu::get_command_encoder(s);
 
   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
     if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
@@ -189,7 +190,7 @@ void RMSNorm::eval_gpu(
         out.copy_shared_buffer(x);
       } else {
         out.set_data(
-            allocator::malloc(x.data_size() * x.itemsize()),
+            cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
             x.data_size(),
             x.strides(),
             x.flags());
@@ -209,7 +210,6 @@ void RMSNorm::eval_gpu(
   int32_t n_rows = x.data_size() / axis_size;
   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(x);
   encoder.set_input_array(w);
   encoder.set_output_array(out);
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc(gx.nbytes()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
      gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
       encoder.add_temporary(gw_temp);
     }
   }

@@ -250,6 +250,7 @@ void RoPE::eval_gpu(
   nvtx3::scoped_range r("RoPE::eval_gpu");
 
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
   auto& in = inputs[0];
   auto& offset = inputs[1];
   auto& out = outputs[0];
@@ -291,14 +292,14 @@ void RoPE::eval_gpu(
       donated = true;
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(allocator::malloc(out.nbytes()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     }
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
   } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
@@ -319,7 +320,6 @@ void RoPE::eval_gpu(
   bool single = in.flags().row_contiguous && B == 1 && T == 1;
   bool with_freqs = inputs.size() == 3;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(donated ? out : in);
   encoder.set_input_array(offset);
   if (with_freqs) {
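
Several of these call sites (RoPE above, Scan and compute_dynamic_offset below) combine the new allocation path with MLX's buffer-donation idiom: reuse an input's buffer when the input is donatable, and only allocate on the stream otherwise. Schematically:

    if (in.is_donatable()) {
      out.copy_shared_buffer(in);  // out aliases in; no allocation at all
    } else {
      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
    }
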

@@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback(
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});
 
-  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
-  sums.set_data(allocator::malloc(sums.nbytes()));
-  maxs.set_data(allocator::malloc(maxs.nbytes()));
+  intermediate.set_data(
+      cu::malloc_async(intermediate.nbytes(), encoder.stream()));
+  sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
+  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
 
   encoder.add_temporary(intermediate);
   encoder.add_temporary(sums);
@@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu(
       };
 
       o.set_data(
-          allocator::malloc(o.nbytes()),
+          cu::malloc_async(o.nbytes(), encoder.stream()),
           o.size(),
           {str_oB, str_oH, str_oL, str_oD},
           flags);

@@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto in = inputs[0];
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
 
   if (in.flags().contiguous && in.strides()[axis_] != 0) {
     if (in.is_donatable() && in.itemsize() == out.itemsize()) {
       out.copy_shared_buffer(in);
     } else {
       out.set_data(
-          allocator::malloc(in.data_size() * out.itemsize()),
+          cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
           in.data_size(),
           in.strides(),
           in.flags());
@@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   int32_t axis_size = in.shape(axis_);
   bool contiguous = in.strides()[axis_] == 1;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
 

@@ -23,14 +23,15 @@ void concatenate_gpu(
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
 
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 
   auto strides = out.strides();
   auto flags = out.flags();
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;
-  auto concurrent = cu::get_command_encoder(s).concurrent_context();
+  auto concurrent = encoder.concurrent_context();
   for (int i = 0; i < inputs.size(); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
     size_t data_offset = strides[axis] * sizes[i];
@@ -80,6 +81,7 @@ array compute_dynamic_offset(
     return std::make_tuple(false, std::move(source), std::vector{kernel_name});
   });
 
+  auto& encoder = cu::get_command_encoder(s);
   // Prepare output.
   array offset({1}, int64, nullptr, {});
   bool donate = indices.is_donatable() &&
@@ -87,10 +89,9 @@ array compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc(offset.itemsize()));
+    offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
   }
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.add_temporary(offset);
   encoder.set_input_array(indices);
   encoder.set_output_array(offset);

@@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Softmax::eval_gpu");
   assert(inputs.size() == 1);
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
 
   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
         out.set_data(
-            allocator::malloc(x.data_size() * x.itemsize()),
+            cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
             x.data_size(),
             x.strides(),
             x.flags());
@@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   int axis_size = in.shape().back();
   int n_rows = in.data_size() / axis_size;
 
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {

@@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     array trans = swapaxes_in_eval(in, axis, last_dim);
     in = contiguous_copy_gpu(trans, s);
     encoder.add_temporary(in);
-    out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+    out = array(
+        cu::malloc_async(out.nbytes(), encoder.stream()),
+        in.shape(),
+        out.dtype());
     encoder.add_temporary(out);
   } else {
     out.set_data(
-        allocator::malloc(in.data_size() * out.itemsize()),
+        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
        in.data_size(),
        in.strides(),
        in.flags());
@@ -70,12 +73,18 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
           thrust::make_counting_iterator(0), OffsetTransform{nsort});
       if (argsort) {
         // Indices in the sorted dimension.
-        array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+        array indices(
+            cu::malloc_async(out.nbytes(), encoder.stream()),
+            in.shape(),
+            out.dtype());
         encoder.add_temporary(indices);
 
         // In argsort though we don't need the result of sorted values, the
         // API requires us to provide an array to store it.
-        array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
+        array discard(
+            cu::malloc_async(in.nbytes(), encoder.stream()),
+            in.shape(),
+            in.dtype());
         encoder.add_temporary(discard);
 
         size_t size;
@@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
             sizeof(Type) * 8,
             stream));
 
-        array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+        array temp(
+            cu::malloc_async(size, encoder.stream()),
+            {static_cast<int>(size)},
+            uint8);
         encoder.add_temporary(temp);
 
         // Start capturing after allocations
@@ -135,7 +147,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
             sizeof(Type) * 8,
             stream));
 
-        array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+        array temp(
+            cu::malloc_async(size, encoder.stream()),
+            {static_cast<int>(size)},
+            uint8);
         encoder.add_temporary(temp);
 
         // Start capturing after allocations
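
The `size_t size;` declaration followed by two near-identical sort dispatches is CUB's standard two-phase temp-storage protocol: the first call with a null scratch pointer only reports the required byte count, and the second does the work. The commit only changes where those scratch bytes come from. A hedged sketch of the protocol (Key, the key/value pointers, and n are placeholders for whatever the dispatch instantiates):

    size_t size = 0;
    // Phase 1: null temp storage, so CUB writes the required bytes into `size`.
    cub::DeviceRadixSort::SortPairs(
        nullptr, size, keys_in, keys_out, vals_in, vals_out, n,
        0, sizeof(Key) * 8, stream);
    // Phase 2: stream-ordered scratch, registered as an encoder temporary.
    array temp(
        cu::malloc_async(size, encoder.stream()), {static_cast<int>(size)}, uint8);
    encoder.add_temporary(temp);
    cub::DeviceRadixSort::SortPairs(
        temp.data<uint8_t>(), size, keys_in, keys_out, vals_in, vals_out, n,
        0, sizeof(Key) * 8, stream);
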

Author: Awni Hannun