mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 07:03:10 +08:00
Merge c093fa72c8
into 3dcb286baf
This commit is contained in:
commit
ece087892d
@ -29,11 +29,18 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
|||||||
|
|
||||||
int cuda_graph_cache_size() {
|
int cuda_graph_cache_size() {
|
||||||
static int cache_size = []() {
|
static int cache_size = []() {
|
||||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
|
||||||
}();
|
}();
|
||||||
return cache_size;
|
return cache_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool use_cuda_graphs() {
|
||||||
|
static bool use_graphs = []() {
|
||||||
|
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||||
|
}();
|
||||||
|
return use_graphs;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
@ -86,11 +93,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
|||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
enc.device().make_current();
|
enc.device().make_current();
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
graph.end_capture(enc.stream());
|
graph.end_capture(enc.stream());
|
||||||
if (discard) {
|
if (discard) {
|
||||||
return;
|
return;
|
||||||
@ -105,6 +119,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
|||||||
|
|
||||||
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
||||||
enc.in_concurrent_ = false;
|
enc.in_concurrent_ = false;
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Use an empty graph node for synchronization
|
// Use an empty graph node for synchronization
|
||||||
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
||||||
@ -193,11 +210,18 @@ void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_input_array(const array& arr) {
|
void CommandEncoder::set_input_array(const array& arr) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||||
active_deps_.push_back(id);
|
active_deps_.push_back(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_output_array(const array& arr) {
|
void CommandEncoder::set_output_array(const array& arr) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||||
active_deps_.push_back(id);
|
active_deps_.push_back(id);
|
||||||
active_outputs_.push_back(id);
|
active_outputs_.push_back(id);
|
||||||
@ -215,6 +239,11 @@ void CommandEncoder::add_kernel_node(
|
|||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CHECK_CUDA_ERROR(cudaLaunchKernel(
|
||||||
|
func, grid_dim, block_dim, params, smem_bytes, stream()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
cudaKernelNodeParams kernel_params = {0};
|
cudaKernelNodeParams kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDim = grid_dim;
|
kernel_params.gridDim = grid_dim;
|
||||||
@ -230,6 +259,22 @@ void CommandEncoder::add_kernel_node(
|
|||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CHECK_CUDA_ERROR(cuLaunchKernel(
|
||||||
|
func,
|
||||||
|
grid_dim.x,
|
||||||
|
grid_dim.y,
|
||||||
|
grid_dim.z,
|
||||||
|
block_dim.x,
|
||||||
|
block_dim.y,
|
||||||
|
block_dim.z,
|
||||||
|
smem_bytes,
|
||||||
|
stream(),
|
||||||
|
params,
|
||||||
|
nullptr));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDimX = grid_dim.x;
|
kernel_params.gridDimX = grid_dim.x;
|
||||||
@ -256,6 +301,12 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CudaGraphExec graph_exec;
|
||||||
|
graph_exec.instantiate(child);
|
||||||
|
device_.make_current();
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
|
||||||
|
}
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||||
|
@ -76,9 +76,6 @@ class CommandEncoder {
|
|||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params);
|
void** params);
|
||||||
|
|
||||||
// Low-level graph helpers.
|
|
||||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
|
||||||
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
|
||||||
void add_graph_node(cudaGraph_t child);
|
void add_graph_node(cudaGraph_t child);
|
||||||
|
|
||||||
void add_temporary(const array& arr) {
|
void add_temporary(const array& arr) {
|
||||||
@ -101,6 +98,9 @@ class CommandEncoder {
|
|||||||
void synchronize();
|
void synchronize();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||||
|
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||||
|
|
||||||
struct GraphNode {
|
struct GraphNode {
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
// K = kernel
|
// K = kernel
|
||||||
|
Loading…
Reference in New Issue
Block a user