mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Switch to CUDA graphs (#2317)
* cuda graph prototype fix signal bug + start to add dependencies capture more capture more ops remaining ops fix reduce and rope deps add concurrent context try update, but not working cosistent topology order use node api use node api directly to reduce overhead fix bug use kernels in unary cache graph format fix synchronization format * comment
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
|
||||
#include <deque>
|
||||
@@ -23,72 +24,48 @@ using KernelBuilderResult = std::pair<
|
||||
/* kernel names */ std::vector<std::string>>;
|
||||
using KernelBuilder = std::function<KernelBuilderResult()>;
|
||||
|
||||
class JitModule {
|
||||
public:
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder);
|
||||
~JitModule();
|
||||
struct KernelArgs {
|
||||
void** args() {
|
||||
return args_.data();
|
||||
}
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
|
||||
void append_arg(const array& a) {
|
||||
append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
|
||||
void append(const array& a) {
|
||||
append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_arg(T val) {
|
||||
void append(T val) {
|
||||
storage_.emplace_back(val);
|
||||
append_ptr_arg(&storage_.back());
|
||||
append_ptr(&storage_.back());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_arg(std::vector<T> vec) {
|
||||
void append(std::vector<T> vec) {
|
||||
if (vec.empty()) {
|
||||
// The nullptr can not be used as arg, pass something not null.
|
||||
append_arg(std::monostate{});
|
||||
append(std::monostate{});
|
||||
} else {
|
||||
append_ptr_arg(vec.data());
|
||||
append_ptr(vec.data());
|
||||
storage_.emplace_back(std::move(vec));
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure the arg is copied to an array with size of NDIM.
|
||||
template <size_t NDIM = MAX_NDIM, typename T>
|
||||
void append_ndim_arg(const std::vector<T>& vec) {
|
||||
void append_ndim(std::vector<T> vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
std::vector<T> copied(NDIM);
|
||||
std::copy(vec.begin(), vec.end(), copied.data());
|
||||
append_arg(std::move(copied));
|
||||
vec.resize(NDIM);
|
||||
append(std::move(vec));
|
||||
}
|
||||
|
||||
// Launch kernel with |kernel_name| that each thread works on
|
||||
// |work_per_thread| elements of |arr|.
|
||||
void launch_kernel(
|
||||
CUstream stream,
|
||||
const std::string& kernel_name,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1);
|
||||
|
||||
void launch_kernel(
|
||||
CUstream stream,
|
||||
CUfunction kernel,
|
||||
Dims num_blocks,
|
||||
Dims block_dims);
|
||||
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
void append_ptr(const void* v) {
|
||||
args_.push_back(const_cast<void*>(v));
|
||||
}
|
||||
|
||||
private:
|
||||
void append_ptr_arg(const void* v);
|
||||
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
std::vector<void*> args_;
|
||||
|
||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||
@@ -105,6 +82,23 @@ class JitModule {
|
||||
std::deque<Arg> storage_;
|
||||
};
|
||||
|
||||
class JitModule {
|
||||
public:
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder);
|
||||
~JitModule();
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
|
||||
private:
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
};
|
||||
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
|
||||
Reference in New Issue
Block a user