wait for tasks in cuda (#2636)

Awni Hannun
2025-09-30 16:08:46 -07:00
committed by GitHub
parent eb24267b56
commit bbf1423953
4 changed files with 26 additions and 16 deletions

View File

@@ -86,7 +86,7 @@ CudaAllocator::CudaAllocator()
  // TODO: Set memory limit for multi-device.
  size_t free, total;
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.8;
+  memory_limit_ = total * 0.95;
  max_pool_size_ = memory_limit_;
 }
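
This hunk raises the CUDA allocator's default memory limit from 80% to 95% of total device memory. A minimal standalone sketch of the same computation, assuming only the cudaMemGetInfo call shown in the diff (the surrounding program is illustrative):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  size_t free = 0, total = 0;
  // cudaMemGetInfo reports the free and total memory of the current device.
  if (cudaMemGetInfo(&free, &total) != cudaSuccess) {
    std::printf("cudaMemGetInfo failed\n");
    return 1;
  }
  // Cap the pool at 95% of total device memory, as the diff now does.
  size_t memory_limit = static_cast<size_t>(total * 0.95);
  std::printf("total: %zu MiB, limit: %zu MiB\n", total >> 20, memory_limit >> 20);
  return 0;
}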

View File

@@ -14,10 +14,6 @@ namespace mlx::core::cu {
 namespace {

-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-// This should be less than 255
-constexpr int default_max_nodes_per_graph = 20;
-
 #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
 void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 CommandEncoder::CaptureContext::~CaptureContext() {
   if (!use_cuda_graphs()) {
+    enc.node_count_++;
     return;
   }
@@ -221,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
   active_outputs_.push_back(id);
 }

-void CommandEncoder::maybe_commit() {
-  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
-    commit();
-  }
-}
-
 void CommandEncoder::add_kernel_node(
     void* func,
     dim3 grid_dim,
@@ -234,6 +225,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cudaLaunchKernel(
         func, grid_dim, block_dim, params, smem_bytes, stream()));
     return;
@@ -254,6 +246,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cuLaunchKernel(
         func,
         grid_dim.x,
@@ -296,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CudaGraphExec graph_exec;
     graph_exec.instantiate(child);
     device_.make_current();
@@ -307,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }

+int CommandEncoder::get_num_ops() {
+  return node_count_;
+}
+
 void CommandEncoder::commit() {
   nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }
-  if (node_count_ > 0) {
+  if (use_cuda_graphs() && node_count_ > 0) {
     if (!from_nodes_.empty()) {
       CHECK_CUDA_ERROR(cudaGraphAddDependencies(
           graph_,
@@ -355,7 +353,6 @@ void CommandEncoder::commit() {
     CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));

     // Reset state
-    node_count_ = 0;
     graph_node_count_ = 0;
     empty_node_count_ = 0;
     from_nodes_.clear();
@@ -367,6 +364,7 @@ void CommandEncoder::commit() {
   // Put completion handlers in a batch.
   worker_.commit(stream_);
+  node_count_ = 0;
 }

 void CommandEncoder::synchronize() {
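
Taken together, these hunks make node_count_ track launched work in the eager (non-graph) path as well as the graph path, expose it through get_num_ops(), and defer its reset to the end of commit() so either path clears it. A minimal sketch of the resulting counting pattern, with the encoder stubbed down to the relevant members (node_count_, get_num_ops, and commit follow the diff; the stub itself is illustrative):

struct EncoderSketch {
  int node_count_ = 0;
  bool use_cuda_graphs_ = true;  // illustrative stand-in for use_cuda_graphs()

  void add_kernel_node(/* launch params elided */) {
    if (!use_cuda_graphs_) {
      node_count_++;  // eager path: count direct launches (added in this commit)
      // ... cudaLaunchKernel on the stream would go here ...
      return;
    }
    node_count_++;  // graph path: count nodes added to the CUDA graph
    // ... add a kernel node to the graph here ...
  }

  // Callers now read the count and decide when to commit.
  int get_num_ops() { return node_count_; }

  void commit() {
    if (use_cuda_graphs_ && node_count_ > 0) {
      // ... instantiate and launch the accumulated graph here ...
    }
    node_count_ = 0;  // reset after either path, as in the diff
  }
};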

View File

@@ -83,7 +83,7 @@ class CommandEncoder {
   }

   void add_completed_handler(std::function<void()> task);
-  void maybe_commit();
+  int get_num_ops();
   void commit();

   Device& device() {

View File

@@ -5,11 +5,15 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h" #include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu { namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() { bool is_available() {
return true; return true;
} }
@@ -36,7 +40,8 @@ void eval(array& arr) {
     arr.primitive().eval_gpu(arr.inputs(), outputs);
   }

-  auto& encoder = cu::get_command_encoder(arr.primitive().stream());
+  auto& stream = arr.primitive().stream();
+  auto& encoder = cu::get_command_encoder(stream);

   // Keep used buffers alive until kernel finishes running.
   for (auto& in : arr.inputs()) {
     // Except for the donated one.
@@ -47,7 +52,14 @@ void eval(array& arr) {
   for (auto& s : arr.siblings()) {
     encoder.add_temporary(s);
   }

-  encoder.maybe_commit();
+  if (encoder.get_num_ops() >=
+      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+    scheduler::notify_new_task(stream);
+    encoder.add_completed_handler(
+        [stream]() { scheduler::notify_task_completion(stream); });
+    encoder.commit();
+  }
 }

 void finalize(Stream s) {
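
This last hunk carries the substance of the commit title: when the encoder fills up, eval() tells the stream's scheduler that a task is in flight (notify_new_task) before committing, and registers a completed handler that posts the matching notify_task_completion once the GPU work drains. Anything waiting on the stream can then block until every committed batch has finished. A self-contained sketch of that pairing; the counter-and-condition-variable design is an assumed illustration, not MLX's actual scheduler:

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>

class TaskCounter {
 public:
  void notify_new_task() {
    std::lock_guard<std::mutex> lk(m_);
    n_tasks_++;
  }
  void notify_task_completion() {
    std::lock_guard<std::mutex> lk(m_);
    if (--n_tasks_ == 0) {
      cv_.notify_all();
    }
  }
  // Block until every task announced via notify_new_task has completed.
  void wait_for_tasks() {
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [this] { return n_tasks_ == 0; });
  }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  int n_tasks_ = 0;
};

int main() {
  TaskCounter scheduler;
  scheduler.notify_new_task();  // paired with encoder.commit() in the diff
  std::thread worker([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));  // "GPU work"
    scheduler.notify_task_completion();  // the completed handler would fire here
  });
  scheduler.wait_for_tasks();  // blocks until the in-flight batch drains
  worker.join();
  return 0;
}

Pairing each commit() with a new-task notification means a later wait on the stream observes all of the work batched into that CUDA graph rather than returning before it has launched.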