[CUDA] Switch to CUDA graphs (#2317)

* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

consistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
Author: Awni Hannun
Committed: 2025-07-02 15:59:13 -07:00 (via GitHub)
Parent: e76e9b87f0
Commit: ec0d5db67b

36 changed files with 1461 additions and 1212 deletions
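For readers unfamiliar with the feature, CUDA graphs let a whole DAG of kernel launches be recorded once, instantiated into an executable graph, and then replayed with a single launch call, which is where the win in this change comes from. Below is a minimal, illustrative sketch of the capture/instantiate/replay pattern using only the stock CUDA runtime API; the scale kernel, sizes, and loop count are placeholders, and none of this is MLX code (note that cudaGraphInstantiate has a 3-argument form on CUDA 12 and a 5-argument form on CUDA 11):

#include <cuda_runtime.h>

// Placeholder kernel; stands in for any op the encoder would record.
__global__ void scale(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= 2.0f;
  }
}

int main() {
  int n = 1 << 20;
  float* x;
  cudaMalloc(&x, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // 1. Record launches into a graph instead of running them eagerly.
  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  scale<<<(n + 255) / 256, 256, 0, stream>>>(x, n);
  scale<<<(n + 255) / 256, 256, 0, stream>>>(x, n);
  cudaStreamEndCapture(stream, &graph);

  // 2. Instantiate once (relatively expensive)...
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, 0); // CUDA 12 signature

  // 3. ...then replay cheaply, amortizing per-launch CPU overhead.
  for (int it = 0; it < 100; ++it) {
    cudaGraphLaunch(exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(x);
  return 0;
}

The encoder in this commit goes further than plain capture: per the "use node api directly" and "cache graph" items above, it assembles graphs node by node and memoizes the instantiated cudaGraphExec_t in graph_cache_ (keyed by graph_key_), so repeated evaluations with the same topology skip both capture and instantiation.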


@@ -7,41 +7,108 @@
 #include "mlx/stream.h"
 
 #include <cublasLt.h>
 #include <cuda.h>
 #include <thrust/execution_policy.h>
 
 #include <unordered_map>
 
 namespace mlx::core::cu {
 
 class Device;
-class CommandEncoder;
 
-class DeviceStream {
+class CommandEncoder {
  public:
-  explicit DeviceStream(Device& device);
+  struct CaptureContext {
+    CaptureContext(CommandEncoder& enc);
+    ~CaptureContext();
+
+    cudaGraph_t graph;
+    CommandEncoder& enc;
+  };
+  struct ConcurrentContext {
+    ConcurrentContext(CommandEncoder& enc);
+    ~ConcurrentContext();
+
+    CommandEncoder& enc;
+  };
 
-  DeviceStream(const DeviceStream&) = delete;
-  DeviceStream& operator=(const DeviceStream&) = delete;
+  explicit CommandEncoder(Device& d);
+  ~CommandEncoder();
 
-  // Wait until kernels in the stream complete.
-  void synchronize();
+  CommandEncoder(const CommandEncoder&) = delete;
+  CommandEncoder& operator=(const CommandEncoder&) = delete;
 
-  // Return a cuda stream for launching kernels.
-  cudaStream_t schedule_cuda_stream();
-
-  // Return the last cuda stream used.
-  cudaStream_t last_cuda_stream();
-
-  CommandEncoder& get_encoder();
-
-  Device& device() {
-    return device_;
+  CaptureContext capture_context() {
+    return CaptureContext{*this};
   }
+  ConcurrentContext concurrent_context() {
+    return ConcurrentContext{*this};
+  }
+
+  void set_input_array(const array& arr);
+  void set_output_array(const array& arr);
+
+  template <typename F, typename... Params>
+  void
+  add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
+    constexpr size_t num = sizeof...(Params);
+    void* ptrs[num];
+    size_t i = 0;
+    ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
+         std::forward<Params>(params)),
+     ...);
+    add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
+  }
+
+  void add_kernel_node(
+      CUfunction func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      void** params);
+
+  void
+  add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
+
+  void add_temporary(const array& arr) {
+    temporaries_.push_back(arr.data_shared_ptr());
+  }
+
+  void add_completed_handler(std::function<void()> task);
+  void maybe_commit();
+  void commit();
+
+  CudaStream& stream() {
+    return stream_;
+  }
+
+  // Wait until kernels and completion handlers are finished
+  void synchronize();
 
  private:
   Device& device_;
+
+  struct GraphNode {
+    cudaGraphNode_t node;
+    // K = kernel
+    // E = empty
+    // G = subgraph
+    char node_type;
+    std::string id;
+  };
+
+  void insert_graph_dependencies(GraphNode node);
+  void insert_graph_dependencies(std::vector<GraphNode> nodes);
+
   CudaStream stream_;
-  std::unique_ptr<CommandEncoder> encoder_;
+  cudaGraph_t graph_;
+  Worker worker_;
+  char node_count_{0};
+  char graph_node_count_{0};
+  char empty_node_count_{0};
+  bool in_concurrent_{false};
+  std::vector<cudaGraphNode_t> from_nodes_;
+  std::vector<cudaGraphNode_t> to_nodes_;
+  std::string graph_key_;
+  std::vector<GraphNode> concurrent_nodes_;
+  std::vector<std::shared_ptr<array::Data>> temporaries_;
+  std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
+  std::vector<std::uintptr_t> active_deps_;
+  std::vector<std::uintptr_t> active_outputs_;
+  std::unordered_map<std::uintptr_t, GraphNode> node_map_;
 };
 
 class Device {
@@ -55,7 +122,7 @@ class Device {
   // Make this device the current cuda device, required by some cuda calls.
   void make_current();
 
-  DeviceStream& get_stream(Stream s);
+  CommandEncoder& get_command_encoder(Stream s);
 
   int cuda_device() const {
     return device_;
@@ -75,67 +142,10 @@ class Device {
   int compute_capability_major_;
   int compute_capability_minor_;
   cublasLtHandle_t lt_;
-  std::unordered_map<int, DeviceStream> streams_;
-};
-
-class CommandEncoder {
- public:
-  explicit CommandEncoder(DeviceStream& stream);
-
-  CommandEncoder(const CommandEncoder&) = delete;
-  CommandEncoder& operator=(const CommandEncoder&) = delete;
-
-  void set_input_array(const array& arr) {}
-  void set_output_array(const array& arr) {}
-
-  void add_temporary(const array& arr) {
-    temporaries_.push_back(arr.data_shared_ptr());
-  }
-
-  void add_completed_handler(std::function<void()> task);
-  void end_encoding();
-  void commit();
-
-  // Schedule a cuda stream for |fun| to launch kernels, and check error
-  // afterwards.
-  template <typename F>
-  void launch_kernel(F&& fun) {
-    launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
-  }
-
-  template <typename F>
-  void launch_kernel(cudaStream_t stream, F&& fun) {
-    device_.make_current();
-    fun(stream);
-    check_cuda_error("kernel launch", cudaGetLastError());
-    has_gpu_work_ = true;
-  }
-
-  Device& device() {
-    return device_;
-  }
-
-  DeviceStream& stream() {
-    return stream_;
-  }
-
-  bool has_gpu_work() const {
-    return has_gpu_work_;
-  }
-
-  // Wait until kernels and completion handlers are finished
-  void synchronize();
-
- private:
-  Device& device_;
-  DeviceStream& stream_;
-  Worker worker_;
-  bool has_gpu_work_{false};
-  std::vector<std::shared_ptr<array::Data>> temporaries_;
+  std::unordered_map<int, CommandEncoder> encoders_;
 };
 
 Device& device(mlx::core::Device device);
-DeviceStream& get_stream(Stream s);
 CommandEncoder& get_command_encoder(Stream s);
 
 // Return an execution policy that does not sync for result.
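A note on how the new encoder hangs together: the variadic add_kernel_node above packs the address of each kernel argument into the void** array that the CUDA graph node API expects, and ConcurrentContext gathers independent nodes and joins them through an empty node (the 'E' in GraphNode::node_type). Below is a hedged sketch of that explicit node-API construction, not MLX's actual implementation; kernA, kernB, and the launch shapes are invented for illustration:

#include <cuda_runtime.h>

// Invented kernels standing in for two independent ops.
__global__ void kernA(float* x) { x[blockIdx.x * blockDim.x + threadIdx.x] += 1.0f; }
__global__ void kernB(float* y) { y[blockIdx.x * blockDim.x + threadIdx.x] *= 2.0f; }

int main() {
  float *x, *y;
  cudaMalloc(&x, 1024 * sizeof(float));
  cudaMalloc(&y, 1024 * sizeof(float));

  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);

  // Kernel arguments are passed as an array of pointers to each argument,
  // which is exactly what the encoder's variadic add_kernel_node assembles.
  void* args_a[] = {&x};
  cudaKernelNodeParams pa = {};
  pa.func = (void*)kernA;
  pa.gridDim = dim3(4);
  pa.blockDim = dim3(256);
  pa.kernelParams = args_a;

  void* args_b[] = {&y};
  cudaKernelNodeParams pb = pa;
  pb.func = (void*)kernB;
  pb.kernelParams = args_b;

  // Two nodes with no dependency between them may run concurrently.
  cudaGraphNode_t a, b;
  cudaGraphAddKernelNode(&a, graph, nullptr, 0, &pa);
  cudaGraphAddKernelNode(&b, graph, nullptr, 0, &pb);

  // An empty node fans the concurrent work back into one dependency,
  // mirroring the 'E' node type used by ConcurrentContext.
  cudaGraphNode_t join;
  cudaGraphNode_t deps[] = {a, b};
  cudaGraphAddEmptyNode(&join, graph, deps, 2);

  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, 0); // CUDA 12 signature
  cudaGraphLaunch(exec, /*stream=*/0);
  cudaDeviceSynchronize();

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaFree(x);
  cudaFree(y);
  return 0;
}

Building nodes directly like this, rather than through stream capture, is what the "use node api directly to reduce overhead" item refers to: it avoids a capture round-trip per kernel and gives the encoder explicit control over dependencies (the from_nodes_ and to_nodes_ vectors in the header above).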