[CUDA] Reduce use of managed memory (#2725)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled

* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
This commit is contained in:
Awni Hannun
2025-11-05 16:05:23 -08:00
committed by GitHub
parent 27778156dc
commit df58b4133a
79 changed files with 795 additions and 515 deletions

View File

@@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
auto* bias_ptr = gpu_ptr<void>(bias);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
@@ -278,9 +278,9 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
nullptr,
alpha);
}
@@ -321,10 +321,10 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
gpu_ptr<void>(c),
alpha,
beta);
}
@@ -370,11 +370,11 @@ void CublasGemm::execute(
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(nbytes),
cu::malloc_async(nbytes, encoder.stream()),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
workspace_ptr = gpu_ptr<void>(workspace);
}
auto capture = encoder.capture_context();

View File

@@ -25,9 +25,10 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
nullptr,
alpha);
a_it.step();
@@ -60,10 +61,11 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
alpha,
beta);
a_it.step();

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
{batch_count * 3},
uint64);
@@ -183,10 +183,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -200,10 +200,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -219,7 +219,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
{batch_count * 4},
uint64);
@@ -271,11 +271,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -290,11 +290,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -312,7 +312,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;

View File

@@ -149,13 +149,13 @@ void gemv(
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
mat = gpu_ptr<DataType>(b);
vec = gpu_ptr<DataType>(a);
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
mat = gpu_ptr<DataType>(a);
vec = gpu_ptr<DataType>(b);
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
@@ -177,7 +177,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols);
} else {
@@ -189,7 +189,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols,
const_param(batch_shape),