Use async cuda malloc managed with cuda 13

This commit is contained in:
Awni Hannun
2025-10-26 16:17:27 -07:00
parent 8f8af61a37
commit 58ccbaaf12
19 changed files with 110 additions and 54 deletions

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"
@@ -93,9 +94,17 @@ CudaAllocator::CudaAllocator()
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.id = 0;
loc.type = cudaMemLocationTypeNone;
cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged);
// TODO set that.
// uint64_t threshold = UINT64_MAX;
#endif
}
Buffer CudaAllocator::malloc(size_t size) {
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
@@ -123,7 +132,12 @@ Buffer CudaAllocator::malloc(size_t size) {
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
cudaError_t err;
if (stream != nullptr && cuda_pool_ != nullptr) {
err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
} else {
err = cudaMallocManaged(&buf->data, size);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
@@ -141,6 +155,14 @@ Buffer CudaAllocator::malloc(size_t size) {
return Buffer{buf};
}
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
return malloc_impl(size, stream);
}
Buffer CudaAllocator::malloc(size_t size) {
return malloc_impl(size, nullptr);
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
@@ -220,6 +242,16 @@ CudaAllocator& allocator() {
return *allocator_;
}
Buffer malloc_async(size_t size, cudaStream_t stream) {
auto buffer = allocator().malloc_async(size, stream);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace cu
namespace allocator {

View File

@@ -5,6 +5,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <utility>
@@ -45,6 +46,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, cudaStream_t stream);
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
@@ -58,6 +60,7 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf);
CudaAllocator();
@@ -70,8 +73,11 @@ class CudaAllocator : public allocator::Allocator {
size_t active_memory_{0};
size_t peak_memory_{0};
SmallSizePool scalar_pool_;
cudaMemPool_t cuda_pool_{nullptr};
};
CudaAllocator& allocator();
Buffer malloc_async(size_t size, cudaStream_t stream);
} // namespace mlx::core::cu

View File

@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {

View File

@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {

View File

@@ -87,8 +87,8 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"

View File

@@ -370,7 +370,7 @@ void CublasGemm::execute(
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(nbytes),
cu::malloc_async(nbytes, encoder.stream()),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
{batch_count * 3},
uint64);
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
{batch_count * 4},
uint64);

View File

@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() > 0);
const auto& src = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_string(idx_dtype),
nidx);
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::vector<std::string> kernel_names;
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_ndim,
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
const auto& src = inputs[0];
const auto& idx = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_string(out.dtype()),
dtype_to_string(idx.dtype()));
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::vector<std::string> kernel_names;
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
idx.flags().row_contiguous,
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}

View File

@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
nvtx3::scoped_range r("LayerNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(b);
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
g_in_gw = true;
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
encoder.add_temporary(gw_temp);
}
}

View File

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
cu::malloc_async(in.nbytes() / n, encoder.stream()),
in.data_size() / n,
std::move(strides),
flags);

View File

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
gemm_and_bias(
encoder,
M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);

View File

@@ -176,9 +176,10 @@ void RMSNorm::eval_gpu(
nvtx3::scoped_range r("RMSNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -189,7 +190,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -209,7 +210,6 @@ void RMSNorm::eval_gpu(
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_output_array(out);
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
encoder.add_temporary(gw_temp);
}
}

View File

@@ -250,6 +250,7 @@ void RoPE::eval_gpu(
nvtx3::scoped_range r("RoPE::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto& in = inputs[0];
auto& offset = inputs[1];
auto& out = outputs[0];
@@ -291,14 +292,14 @@ void RoPE::eval_gpu(
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
@@ -319,7 +320,6 @@ void RoPE::eval_gpu(
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset);
if (with_freqs) {

View File

@@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback(
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
@@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu(
};
o.set_data(
allocator::malloc(o.nbytes()),
cu::malloc_async(o.nbytes(), encoder.stream()),
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);

View File

@@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto in = inputs[0];
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
in.data_size(),
in.strides(),
in.flags());
@@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t axis_size = in.shape(axis_);
bool contiguous = in.strides()[axis_] == 1;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);

View File

@@ -23,14 +23,15 @@ void concatenate_gpu(
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
auto concurrent = cu::get_command_encoder(s).concurrent_context();
auto concurrent = encoder.concurrent_context();
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i];
@@ -80,6 +81,7 @@ array compute_dynamic_offset(
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
});
auto& encoder = cu::get_command_encoder(s);
// Prepare output.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
@@ -87,10 +89,9 @@ array compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
}
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(offset);
encoder.set_input_array(indices);
encoder.set_output_array(offset);

View File

@@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Softmax::eval_gpu");
assert(inputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {

View File

@@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s);
encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
out = array(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
in.data_size(),
in.strides(),
in.flags());
@@ -70,12 +73,18 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
thrust::make_counting_iterator(0), OffsetTransform{nsort});
if (argsort) {
// Indices in the sorted dimension.
array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
array indices(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
encoder.add_temporary(indices);
// In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it.
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
array discard(
cu::malloc_async(in.nbytes(), encoder.stream()),
in.shape(),
in.dtype());
encoder.add_temporary(discard);
size_t size;
@@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
sizeof(Type) * 8,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
@@ -135,7 +147,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
sizeof(Type) * 8,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
encoder.add_temporary(temp);
// Start capturing after allocations