mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-11 23:14:50 +08:00
wait for tasks in cuda (#2636)
This commit is contained in:
@@ -86,7 +86,7 @@ CudaAllocator::CudaAllocator()
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
memory_limit_ = total * 0.95;
|
||||
max_pool_size_ = memory_limit_;
|
||||
}
|
||||
|
||||
|
||||
@@ -14,10 +14,6 @@ namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||
// This should be less than 255
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||
|
||||
void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
|
||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||
if (!use_cuda_graphs()) {
|
||||
enc.node_count_++;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -221,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
|
||||
active_outputs_.push_back(id);
|
||||
}
|
||||
|
||||
void CommandEncoder::maybe_commit() {
|
||||
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
|
||||
commit();
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(
|
||||
void* func,
|
||||
dim3 grid_dim,
|
||||
@@ -234,6 +225,7 @@ void CommandEncoder::add_kernel_node(
|
||||
uint32_t smem_bytes,
|
||||
void** params) {
|
||||
if (!use_cuda_graphs()) {
|
||||
node_count_++;
|
||||
CHECK_CUDA_ERROR(cudaLaunchKernel(
|
||||
func, grid_dim, block_dim, params, smem_bytes, stream()));
|
||||
return;
|
||||
@@ -254,6 +246,7 @@ void CommandEncoder::add_kernel_node(
|
||||
uint32_t smem_bytes,
|
||||
void** params) {
|
||||
if (!use_cuda_graphs()) {
|
||||
node_count_++;
|
||||
CHECK_CUDA_ERROR(cuLaunchKernel(
|
||||
func,
|
||||
grid_dim.x,
|
||||
@@ -296,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||
|
||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
if (!use_cuda_graphs()) {
|
||||
node_count_++;
|
||||
CudaGraphExec graph_exec;
|
||||
graph_exec.instantiate(child);
|
||||
device_.make_current();
|
||||
@@ -307,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||
}
|
||||
|
||||
int CommandEncoder::get_num_ops() {
|
||||
return node_count_;
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
nvtx3::scoped_range r("CommandEncoder::commit");
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
if (node_count_ > 0) {
|
||||
if (use_cuda_graphs() && node_count_ > 0) {
|
||||
if (!from_nodes_.empty()) {
|
||||
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
||||
graph_,
|
||||
@@ -355,7 +353,6 @@ void CommandEncoder::commit() {
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
|
||||
// Reset state
|
||||
node_count_ = 0;
|
||||
graph_node_count_ = 0;
|
||||
empty_node_count_ = 0;
|
||||
from_nodes_.clear();
|
||||
@@ -367,6 +364,7 @@ void CommandEncoder::commit() {
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.commit(stream_);
|
||||
node_count_ = 0;
|
||||
}
|
||||
|
||||
void CommandEncoder::synchronize() {
|
||||
|
||||
@@ -83,7 +83,7 @@ class CommandEncoder {
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void maybe_commit();
|
||||
int get_num_ops();
|
||||
void commit();
|
||||
|
||||
Device& device() {
|
||||
|
||||
@@ -5,11 +5,15 @@
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
@@ -36,7 +40,8 @@ void eval(array& arr) {
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
auto& stream = arr.primitive().stream();
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
for (auto& in : arr.inputs()) {
|
||||
// Except for the donated one.
|
||||
@@ -47,7 +52,14 @@ void eval(array& arr) {
|
||||
for (auto& s : arr.siblings()) {
|
||||
encoder.add_temporary(s);
|
||||
}
|
||||
encoder.maybe_commit();
|
||||
|
||||
if (encoder.get_num_ops() >=
|
||||
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
|
||||
scheduler::notify_new_task(stream);
|
||||
encoder.add_completed_handler(
|
||||
[stream]() { scheduler::notify_task_completion(stream); });
|
||||
encoder.commit();
|
||||
}
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
|
||||
Reference in New Issue
Block a user