wait for tasks in cuda (#2636)

Author: Awni Hannun
Date:   2025-09-30 16:08:46 -07:00
Committed by: GitHub
Parent: eb24267b56
Commit: bbf1423953

4 changed files with 26 additions and 16 deletions
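
The commit ships without a description; summarizing from the diff below: the CUDA allocator's memory limit rises from 80% to 95% of device memory, CommandEncoder counts enqueued ops even when CUDA graphs are disabled and exposes the count through get_num_ops(), and the batching threshold check moves out of CommandEncoder::maybe_commit() into gpu::eval(), which now registers a task with the scheduler before committing and signals completion from a completed handler, so waiting on a stream's tasks also waits for committed CUDA work.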

mlx/backend/cuda/allocator.cpp

@@ -86,7 +86,7 @@ CudaAllocator::CudaAllocator()
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.8;
+  memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
 }
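
For context, the policy this hunk tunes, as a standalone sketch: the allocator derives its pool cap from cudaMemGetInfo at construction, and the cap moves from 80% to 95% of total device memory. Only cudaMemGetInfo and cudaGetErrorString are real CUDA APIs here; the rest is illustrative:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  size_t free = 0, total = 0;
  cudaError_t err = cudaMemGetInfo(&free, &total);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaMemGetInfo: %s\n", cudaGetErrorString(err));
    return 1;
  }
  // New policy from this hunk: allow the pool to grow to 95% of device
  // memory (previously 80%).
  size_t memory_limit = static_cast<size_t>(total * 0.95);
  std::printf("total: %zu B, free: %zu B, limit: %zu B\n",
              total, free, memory_limit);
  return 0;
}
```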

mlx/backend/cuda/device.cpp

@@ -14,10 +14,6 @@ namespace mlx::core::cu {

 namespace {

-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-// This should be less than 255
-constexpr int default_max_nodes_per_graph = 20;
-
 #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))

 void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {

 CommandEncoder::CaptureContext::~CaptureContext() {
   if (!use_cuda_graphs()) {
+    enc.node_count_++;
     return;
   }
@@ -221,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
   active_outputs_.push_back(id);
 }

-void CommandEncoder::maybe_commit() {
-  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
-    commit();
-  }
-}
-
 void CommandEncoder::add_kernel_node(
     void* func,
     dim3 grid_dim,
@@ -234,6 +225,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cudaLaunchKernel(
         func, grid_dim, block_dim, params, smem_bytes, stream()));
     return;
@@ -254,6 +246,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cuLaunchKernel(
         func,
         grid_dim.x,
@@ -296,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {

 void CommandEncoder::add_graph_node(cudaGraph_t child) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CudaGraphExec graph_exec;
     graph_exec.instantiate(child);
     device_.make_current();
@@ -307,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }

+int CommandEncoder::get_num_ops() {
+  return node_count_;
+}
+
 void CommandEncoder::commit() {
   nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }
-  if (node_count_ > 0) {
+  if (use_cuda_graphs() && node_count_ > 0) {
     if (!from_nodes_.empty()) {
       CHECK_CUDA_ERROR(cudaGraphAddDependencies(
           graph_,
@@ -355,7 +353,6 @@ void CommandEncoder::commit() {
     CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));

     // Reset state
-    node_count_ = 0;
     graph_node_count_ = 0;
     empty_node_count_ = 0;
     from_nodes_.clear();
@@ -367,6 +364,7 @@ void CommandEncoder::commit() {
   // Put completion handlers in a batch.
   worker_.commit(stream_);
+  node_count_ = 0;
 }

 void CommandEncoder::synchronize() {
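
A toy model of the counting behavior these hunks establish (MiniEncoder and its members are invented for illustration, not MLX's real internals): every enqueue path bumps the counter whether CUDA graphs are in use or not, and commit() resets it unconditionally, so the caller-side threshold check added in eval.cpp below stays accurate in both modes:

```cpp
#include <cstdio>

struct MiniEncoder {
  int node_count = 0;
  bool use_cuda_graphs = false;

  void add_kernel_node() {
    if (!use_cuda_graphs) {
      node_count++;  // eager fallback now counts too (the fix)
      // ... launch the kernel directly on the stream here ...
      return;
    }
    node_count++;  // graph path: record a node in the CUDA graph
  }

  int get_num_ops() const { return node_count; }

  void commit() {
    // ... instantiate + launch the graph, or flush the stream ...
    node_count = 0;  // reset on both paths, not just the graph path
  }
};

int main() {
  MiniEncoder enc;
  for (int i = 0; i < 5; ++i) enc.add_kernel_node();
  std::printf("ops before commit: %d\n", enc.get_num_ops());  // 5
  enc.commit();
  std::printf("ops after commit: %d\n", enc.get_num_ops());   // 0
  return 0;
}
```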

mlx/backend/cuda/device.h

@@ -83,7 +83,7 @@ class CommandEncoder {
   }

   void add_completed_handler(std::function<void()> task);
-  void maybe_commit();
+  int get_num_ops();
   void commit();

   Device& device() {

mlx/backend/cuda/eval.cpp

@@ -5,11 +5,15 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/gpu/available.h"
 #include "mlx/primitives.h"
+#include "mlx/scheduler.h"

 #include <nvtx3/nvtx3.hpp>

 namespace mlx::core::gpu {

+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+constexpr int default_max_nodes_per_graph = 20;
+
 bool is_available() {
   return true;
 }
@@ -36,7 +40,8 @@ void eval(array& arr) {
     arr.primitive().eval_gpu(arr.inputs(), outputs);
   }

-  auto& encoder = cu::get_command_encoder(arr.primitive().stream());
+  auto& stream = arr.primitive().stream();
+  auto& encoder = cu::get_command_encoder(stream);

   // Keep used buffers alive until kernel finishes running.
   for (auto& in : arr.inputs()) {
     // Except for the donated one.
@@ -47,7 +52,14 @@ void eval(array& arr) {
   for (auto& s : arr.siblings()) {
     encoder.add_temporary(s);
   }
-  encoder.maybe_commit();
+  if (encoder.get_num_ops() >=
+      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+    scheduler::notify_new_task(stream);
+    encoder.add_completed_handler(
+        [stream]() { scheduler::notify_task_completion(stream); });
+    encoder.commit();
+  }
 }

 void finalize(Stream s) {
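
A hedged, self-contained sketch of the new commit path (Scheduler, Encoder, and Stream are stand-ins here, not MLX's real types): marking a new task on the stream before committing, then clearing it from a completed handler, is what lets a waiter observe in-flight CUDA batches, which is the point of the commit title:

```cpp
#include <cstdio>
#include <functional>
#include <vector>

struct Stream {
  int index;
};

// Stand-in for the scheduler's task accounting.
struct Scheduler {
  int n_active_tasks = 0;
  void notify_new_task(const Stream&) { n_active_tasks++; }
  void notify_task_completion(const Stream&) { n_active_tasks--; }
  bool idle() const { return n_active_tasks == 0; }
};

Scheduler scheduler;  // global singleton, for illustration only

struct Encoder {
  int ops = 0;
  std::vector<std::function<void()>> handlers;

  int get_num_ops() const { return ops; }
  void add_completed_handler(std::function<void()> f) {
    handlers.push_back(std::move(f));
  }
  void commit() {
    // MLX runs handlers on a worker thread after the CUDA stream
    // finishes; we run them inline to keep the sketch runnable.
    for (auto& h : handlers) h();
    handlers.clear();
    ops = 0;
  }
};

int main() {
  Encoder encoder;
  Stream stream{0};
  encoder.ops = 21;  // pretend we crossed the 20-op default threshold

  if (encoder.get_num_ops() >= 20) {
    scheduler.notify_new_task(stream);
    encoder.add_completed_handler(
        [stream]() { scheduler.notify_task_completion(stream); });
    encoder.commit();
  }
  std::printf("scheduler idle: %d\n", scheduler.idle());  // 1
  return 0;
}
```

Ordering appears deliberate: notify_new_task runs before commit, so there is no window in which the batch is in flight but the stream looks idle to anyone waiting on its tasks.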