mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 18:26:41 +08:00
enable cuda graph toggle
This commit is contained in:
parent
333ffea273
commit
4ba4544549
@ -34,6 +34,14 @@ int cuda_graph_cache_size() {
|
|||||||
return cache_size;
|
return cache_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool use_cuda_graphs() {
|
||||||
|
static bool use_graphs = []() {
|
||||||
|
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||||
|
}();
|
||||||
|
return use_graphs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
@ -86,11 +94,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
|||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
enc.device().make_current();
|
enc.device().make_current();
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
graph.end_capture(enc.stream());
|
graph.end_capture(enc.stream());
|
||||||
if (discard) {
|
if (discard) {
|
||||||
return;
|
return;
|
||||||
@ -105,6 +120,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
|||||||
|
|
||||||
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
||||||
enc.in_concurrent_ = false;
|
enc.in_concurrent_ = false;
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Use an empty graph node for synchronization
|
// Use an empty graph node for synchronization
|
||||||
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
||||||
@ -193,11 +211,18 @@ void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_input_array(const array& arr) {
|
void CommandEncoder::set_input_array(const array& arr) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||||
active_deps_.push_back(id);
|
active_deps_.push_back(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_output_array(const array& arr) {
|
void CommandEncoder::set_output_array(const array& arr) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||||
active_deps_.push_back(id);
|
active_deps_.push_back(id);
|
||||||
active_outputs_.push_back(id);
|
active_outputs_.push_back(id);
|
||||||
@ -215,6 +240,17 @@ void CommandEncoder::add_kernel_node(
|
|||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CHECK_CUDA_ERROR(cudaLaunchKernel(
|
||||||
|
func,
|
||||||
|
grid_dim,
|
||||||
|
block_dim,
|
||||||
|
params,
|
||||||
|
smem_bytes,
|
||||||
|
stream()
|
||||||
|
));
|
||||||
|
return;
|
||||||
|
}
|
||||||
cudaKernelNodeParams kernel_params = {0};
|
cudaKernelNodeParams kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDim = grid_dim;
|
kernel_params.gridDim = grid_dim;
|
||||||
@ -230,6 +266,23 @@ void CommandEncoder::add_kernel_node(
|
|||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CHECK_CUDA_ERROR(cuLaunchKernel(
|
||||||
|
func,
|
||||||
|
grid_dim.x,
|
||||||
|
grid_dim.y,
|
||||||
|
grid_dim.z,
|
||||||
|
block_dim.x,
|
||||||
|
block_dim.y,
|
||||||
|
block_dim.z,
|
||||||
|
smem_bytes,
|
||||||
|
stream(),
|
||||||
|
params,
|
||||||
|
nullptr
|
||||||
|
));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDimX = grid_dim.x;
|
kernel_params.gridDimX = grid_dim.x;
|
||||||
@ -256,6 +309,12 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CudaGraphExec graph_exec;
|
||||||
|
graph_exec.instantiate(child);
|
||||||
|
device_.make_current();
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
|
||||||
|
}
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||||
|
@ -76,9 +76,6 @@ class CommandEncoder {
|
|||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params);
|
void** params);
|
||||||
|
|
||||||
// Low-level graph helpers.
|
|
||||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
|
||||||
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
|
||||||
void add_graph_node(cudaGraph_t child);
|
void add_graph_node(cudaGraph_t child);
|
||||||
|
|
||||||
void add_temporary(const array& arr) {
|
void add_temporary(const array& arr) {
|
||||||
@ -101,6 +98,9 @@ class CommandEncoder {
|
|||||||
void synchronize();
|
void synchronize();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||||
|
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||||
|
|
||||||
struct GraphNode {
|
struct GraphNode {
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
// K = kernel
|
// K = kernel
|
||||||
|
Loading…
Reference in New Issue
Block a user