Mirror of https://github.com/ml-explore/mlx.git

Commit: wait for tasks in cuda (#2636)
mlx/backend/cuda/allocator.cpp

@@ -86,7 +86,7 @@ CudaAllocator::CudaAllocator()
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.8;
+  memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
 }
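Note: the allocator hunk raises the pool cap from 80% to 95% of total device memory, queried once at construction. A minimal standalone sketch of the same query via the CUDA runtime API (error handling simplified; the 0.95 factor mirrors the new cap):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      size_t free_bytes = 0, total_bytes = 0;
      // Same query the allocator performs at construction time.
      if (cudaMemGetInfo(&free_bytes, &total_bytes) != cudaSuccess) {
        std::fprintf(stderr, "cudaMemGetInfo failed\n");
        return 1;
      }
      // Policy from this commit: cap the pool at 95% of total memory.
      size_t memory_limit = static_cast<size_t>(total_bytes * 0.95);
      std::printf("free=%zu total=%zu limit=%zu\n",
                  free_bytes, total_bytes, memory_limit);
      return 0;
    }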
mlx/backend/cuda/device.cpp

@@ -14,10 +14,6 @@ namespace mlx::core::cu {

 namespace {

-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-// This should be less than 255
-constexpr int default_max_nodes_per_graph = 20;
-
 #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))

 void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {

 CommandEncoder::CaptureContext::~CaptureContext() {
   if (!use_cuda_graphs()) {
+    enc.node_count_++;
     return;
   }

@@ -221,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
   active_outputs_.push_back(id);
 }

-void CommandEncoder::maybe_commit() {
-  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
-    commit();
-  }
-}
-
 void CommandEncoder::add_kernel_node(
     void* func,
     dim3 grid_dim,
@@ -234,6 +225,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cudaLaunchKernel(
         func, grid_dim, block_dim, params, smem_bytes, stream()));
     return;
@@ -254,6 +246,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cuLaunchKernel(
         func,
         grid_dim.x,
@@ -296,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {

 void CommandEncoder::add_graph_node(cudaGraph_t child) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CudaGraphExec graph_exec;
     graph_exec.instantiate(child);
     device_.make_current();
@@ -307,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }

+int CommandEncoder::get_num_ops() {
+  return node_count_;
+}
+
 void CommandEncoder::commit() {
   nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }
-  if (node_count_ > 0) {
+  if (use_cuda_graphs() && node_count_ > 0) {
     if (!from_nodes_.empty()) {
       CHECK_CUDA_ERROR(cudaGraphAddDependencies(
           graph_,
@@ -355,7 +353,6 @@ void CommandEncoder::commit() {
     CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));

     // Reset state
-    node_count_ = 0;
     graph_node_count_ = 0;
     empty_node_count_ = 0;
     from_nodes_.clear();
@@ -367,6 +364,7 @@ void CommandEncoder::commit() {

   // Put completion handlers in a batch.
   worker_.commit(stream_);
+  node_count_ = 0;
 }

 void CommandEncoder::synchronize() {
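Note: two details in the commit() hunks above are easy to miss. The accumulated graph is now only launched under use_cuda_graphs(), and the node_count_ reset moves from the graph-only branch to the tail of commit(), after worker_.commit(stream_). Without that move, eager (non-graph) mode would never clear the counter, and every subsequent op would cross the threshold and trigger another commit. A compilable sketch of the new ordering (member names from the diff, bodies elided):

    // Sketch only: the real commit() also wires graph dependencies,
    // instantiates and launches the CUDA graph, and batches handlers.
    struct CommitSketch {
      int node_count_ = 0;
      bool use_cuda_graphs_ = true;  // assumption: toggled by an env switch

      void commit() {
        if (use_cuda_graphs_ && node_count_ > 0) {
          // ...instantiate + launch accumulated graph, reset graph state...
        }
        // ...worker_.commit(stream_): flush completion handlers...
        node_count_ = 0;  // previously reset only inside the branch above
      }
    };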
mlx/backend/cuda/device.h

@@ -83,7 +83,7 @@ class CommandEncoder {
   }

   void add_completed_handler(std::function<void()> task);
-  void maybe_commit();
+  int get_num_ops();
   void commit();

   Device& device() {
mlx/backend/cuda/eval.cpp

@@ -5,11 +5,15 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/gpu/available.h"
 #include "mlx/primitives.h"
+#include "mlx/scheduler.h"

 #include <nvtx3/nvtx3.hpp>

 namespace mlx::core::gpu {

+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+constexpr int default_max_nodes_per_graph = 20;
+
 bool is_available() {
   return true;
 }
@@ -36,7 +40,8 @@ void eval(array& arr) {
     arr.primitive().eval_gpu(arr.inputs(), outputs);
   }

-  auto& encoder = cu::get_command_encoder(arr.primitive().stream());
+  auto& stream = arr.primitive().stream();
+  auto& encoder = cu::get_command_encoder(stream);
   // Keep used buffers alive until kernel finishes running.
   for (auto& in : arr.inputs()) {
     // Except for the donated one.
@@ -47,7 +52,14 @@ void eval(array& arr) {
   for (auto& s : arr.siblings()) {
     encoder.add_temporary(s);
   }
-  encoder.maybe_commit();
+  if (encoder.get_num_ops() >=
+      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+    scheduler::notify_new_task(stream);
+    encoder.add_completed_handler(
+        [stream]() { scheduler::notify_task_completion(stream); });
+    encoder.commit();
+  }
 }

 void finalize(Stream s) {
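Note: this last hunk is the heart of the change and the source of the commit title. Before committing a batch, eval() registers a task with the scheduler for the stream (the `stream` variable was hoisted in the previous hunk so the lambda can capture it), and the completion handler marks the task done only once the CUDA work has actually finished. Anything that waits on the scheduler's task count therefore waits for in-flight CUDA work, not just queued host work. A self-contained sketch of that notify/wait pairing, with a hypothetical counter standing in for mlx's scheduler:

    #include <condition_variable>
    #include <mutex>

    // Hypothetical per-stream task counter illustrating the pairing of
    // notify_new_task (before commit) and notify_task_completion (handler).
    struct TaskCounterSketch {
      std::mutex m;
      std::condition_variable cv;
      int outstanding = 0;

      void notify_new_task() {
        std::lock_guard<std::mutex> lk(m);
        ++outstanding;
      }
      void notify_task_completion() {
        {
          std::lock_guard<std::mutex> lk(m);
          --outstanding;
        }
        cv.notify_all();
      }
      // What "wait for tasks" enables: block until committed work completes.
      void wait_for_tasks() {
        std::unique_lock<std::mutex> lk(m);
        cv.wait(lk, [&] { return outstanding == 0; });
      }
    };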