Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-28 00:36:32 +08:00)

commit d6b204b528 ("comments"), parent fa56bf2feb
@@ -49,7 +49,7 @@ std::string template_arguments_hash(
 }
 
 std::string build_kernel(
-    std::string func_name,
+    const std::string& func_name,
     const std::string& header,
     const std::string& source,
     const std::vector<std::string>& input_names,
@@ -316,7 +316,7 @@ void CustomKernel::eval_gpu(
       name_,
       [&]() {
         return std::make_tuple(
-            is_precompiled_, source_, std::vector<std::string>{kernel_name});
+            is_precompiled_, source_, std::vector{kernel_name});
       },
       false);
 
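Two small notes on the hunks above: build_kernel now takes func_name by const reference to avoid a copy, and the tuple construction relies on C++17 class template argument deduction, so the element type no longer has to be spelled out. A minimal standalone illustration (names here are illustrative, not MLX code):

#include <string>
#include <type_traits>
#include <vector>

int main() {
  std::string kernel_name = "my_kernel";
  // C++17 deduces the element type from the initializer, so this is a
  // std::vector<std::string>, exactly like the explicit spelling it replaces.
  auto kernels = std::vector{kernel_name};
  static_assert(std::is_same_v<decltype(kernels), std::vector<std::string>>);
  return 0;
}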
@@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
 bool read_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
-    std::vector<char>* ptx,
-    std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+    std::string& ptx,
+    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
   if (cache_dir.empty()) {
     return false;
   }
@@ -117,15 +117,15 @@ bool read_cached_ptx(
   if (!ptx_file.good()) {
     return false;
   }
-  ptx->resize(ptx_size);
-  ptx_file.read(ptx->data(), ptx_size);
+  ptx.resize(ptx_size);
+  ptx_file.read(ptx.data(), ptx_size);
 
   std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
   std::string line;
   while (std::getline(txt_file, line)) {
     auto tab = line.find('\t');
     if (tab != std::string::npos) {
-      ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
+      ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
     }
   }
   return true;
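For context, a minimal sketch of what the reference-based read path looks like end to end, assuming the <module>.ptx / <module>.txt cache layout implied by the hunks above. This is an illustration, not the actual MLX implementation (it slurps the whole blob instead of using the precomputed ptx_size):

#include <filesystem>
#include <fstream>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

bool read_cached_ptx_sketch(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    std::string& ptx,
    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
  if (cache_dir.empty()) {
    return false;
  }
  std::ifstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
  if (!ptx_file.good()) {
    return false;
  }
  // Read the PTX/cubin blob straight into the output string.
  ptx.assign(
      std::istreambuf_iterator<char>(ptx_file),
      std::istreambuf_iterator<char>());

  // Each line of the .txt file maps one kernel, the two halves separated by a tab.
  std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
  std::string line;
  while (std::getline(txt_file, line)) {
    auto tab = line.find('\t');
    if (tab != std::string::npos) {
      ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
    }
  }
  return true;
}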
@@ -135,7 +135,7 @@ bool read_cached_ptx(
 void write_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
-    const std::vector<char>& ptx,
+    const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     const std::string& source_code) {
   if (cache_dir.empty()) {
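The matching write side, again only as a sketch under the same assumed file layout (the source_code argument is left out here):

#include <filesystem>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

void write_cached_ptx_sketch(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    const std::string& ptx,
    const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
  if (cache_dir.empty()) {
    return;
  }
  // Persist the PTX/cubin blob as-is.
  std::ofstream(cache_dir / (module_name + ".ptx"), std::ios::binary)
      .write(ptx.data(), ptx.size());
  // One tab-separated entry per kernel, matching the read path above.
  std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
  for (const auto& [first, second] : ptx_kernels) {
    txt_file << first << '\t' << second << '\n';
  }
}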
@@ -222,7 +222,7 @@ void compile(
     const std::string& module_name,
     const std::string& source,
     const std::vector<std::string>& kernel_names,
-    std::vector<char>& ptx,
+    std::string& ptx,
     std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
   // Create the program
   nvrtcProgram prog;
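compile() is NVRTC-based, as the CHECK_NVRTC_ERROR calls in the next hunk show. A hedged sketch of the basic NVRTC flow that fills a std::string with PTX; the options and error handling are illustrative, and this is not MLX's compile():

#include <nvrtc.h>
#include <string>

std::string compile_to_ptx_sketch(
    const std::string& module_name,
    const std::string& source) {
  // Create and compile the program (error checking omitted for brevity).
  nvrtcProgram prog;
  nvrtcCreateProgram(
      &prog, source.c_str(), module_name.c_str(), 0, nullptr, nullptr);
  const char* opts[] = {"--std=c++17"};
  nvrtcCompileProgram(prog, 1, opts);

  // Query the PTX size, then copy it straight into a std::string.
  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);  // size includes the trailing NUL
  std::string ptx(ptx_size, '\0');
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);
  return ptx;
}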
@@ -282,7 +282,7 @@ void compile(
   } else {
     CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
   }
-  ptx.resize(ptx_size, 0);
+  ptx.resize(ptx_size);
   if (use_sass) {
     CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
   } else {
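Dropping the explicit fill value is safe because std::string::resize(n), like std::vector::resize, already value-initializes the new characters to '\0'. A one-line check:

#include <cassert>
#include <string>

int main() {
  std::string buf;
  buf.resize(8);  // appended characters are value-initialized to '\0'
  assert(buf.size() == 8 && buf[0] == '\0' && buf[7] == '\0');
  return 0;
}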
@@ -292,7 +292,7 @@ void compile(
 
 void load_module(
     const std::string& module_name,
-    const std::vector<char>& ptx,
+    const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     CUmodule& module_,
     std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
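A sketch of the driver-API loading that this signature implies: cuModuleLoadData accepts the PTX or cubin image held in the string, and each kernel is then resolved by name. Which half of the pair is the lookup symbol and which is the map key is an assumption here, and this is not MLX's implementation:

#include <cuda.h>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

void load_module_sketch(
    const std::string& ptx,
    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
    CUmodule& module_,
    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
  // cuModuleLoadData takes a pointer to a NUL-terminated PTX image or a cubin.
  cuModuleLoadData(&module_, ptx.data());
  for (const auto& [symbol, name] : ptx_kernels) {
    CUfunction f = nullptr;
    cuModuleGetFunction(&f, module_, symbol.c_str());
    // The bool records whether the kernel has been configured yet
    // (see JitModule::get_kernel below).
    kernels[name] = {f, false};
  }
}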
@@ -322,18 +322,18 @@ JitModule::JitModule(
     Device& device,
     const std::string& module_name,
     const KernelBuilder& builder,
-    bool cache) {
+    bool use_disk_cache) {
   // Will hold the actual device executable source code and kernel names
-  std::vector<char> ptx;
+  std::string ptx;
   std::vector<std::pair<std::string, std::string>> ptx_kernels;
 
   // Try to load them from the file cache
-  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
     auto [precompiled, source_code, kernel_names] = builder();
 
     // Get the PTX or cubin
     if (precompiled) {
-      ptx.insert(ptx.begin(), source_code.begin(), source_code.end());
+      ptx = std::move(source_code);
       for (auto& name : kernel_names) {
         ptx_kernels.emplace_back(name, name);
       }
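The std::string representation is what makes the precompiled branch cheap: the builder's source_code can be moved into ptx instead of being copied element by element into a std::vector<char>. A minimal illustration (the tuple shape mirrors the builder's result, the names are illustrative):

#include <string>
#include <tuple>
#include <utility>
#include <vector>

std::tuple<bool, std::string, std::vector<std::string>> fake_builder() {
  return {true, std::string(1 << 20, 'x'), std::vector<std::string>{"kernel_a"}};
}

int main() {
  std::string ptx;
  auto [precompiled, source_code, kernel_names] = fake_builder();
  if (precompiled) {
    // Transfers the buffer in O(1); the old vector<char>::insert copied it.
    ptx = std::move(source_code);
  }
  return 0;
}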
@@ -342,7 +342,7 @@ JitModule::JitModule(
     }
 
     // If requested save them in the file cache for the next launch
-    if (cache) {
+    if (use_disk_cache) {
      write_cached_ptx(
          ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
    }
@@ -368,7 +368,9 @@ CUfunction JitModule::get_kernel(
   // If it is the first time we run this kernel then configure it. Do it only
   // once!
   if (!it->second.second) {
-    configure_kernel(it->second.first);
+    if (configure_kernel) {
+      configure_kernel(it->second.first);
+    }
     it->second.second = true;
   }
 
@@ -64,8 +64,8 @@ struct KernelArgs {
  private:
   std::vector<void*> args_;
 
-  // The cuLaunchKernel API requires passing pointers to arguments so store
-  // temporary values until the kernel is launched.
+  // The cuGraphAddKernelNode API requires passing pointers to arguments so
+  // store temporary values until the node is created.
   using Arg = std::variant<
       std::monostate,
       CUdeviceptr,
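For reference, a hedged sketch of why argument storage of this kind has to stay alive: cuGraphAddKernelNode only receives a void** into the stored values, which must remain valid at least until the node is created. Grid/block sizes and names here are illustrative, not MLX's KernelArgs:

#include <cuda.h>
#include <vector>

CUgraphNode add_kernel_node_sketch(
    CUgraph graph, CUfunction kernel, CUdeviceptr buffer, int n) {
  // The pointed-to values must outlive the cuGraphAddKernelNode call.
  std::vector<void*> args = {&buffer, &n};

  CUDA_KERNEL_NODE_PARAMS params = {};
  params.func = kernel;
  params.gridDimX = 1;
  params.gridDimY = 1;
  params.gridDimZ = 1;
  params.blockDimX = 256;
  params.blockDimY = 1;
  params.blockDimZ = 1;
  params.sharedMemBytes = 0;
  params.kernelParams = args.data();

  CUgraphNode node = nullptr;
  cuGraphAddKernelNode(&node, graph, nullptr, 0, &params);
  return node;
}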
@@ -93,10 +93,7 @@ class JitModule {
   JitModule& operator=(const JitModule&) = delete;
   CUfunction get_kernel(
       const std::string& kernel_name,
-      std::function<void(CUfunction)> configure_kernel);
-  CUfunction get_kernel(const std::string& kernel_name) {
-    return get_kernel(kernel_name, [](auto k) {});
-  }
+      std::function<void(CUfunction)> configure_kernel = nullptr);
 
  private:
  CUmodule module_{nullptr};
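This declaration replaces the old two-overload get_kernel with a single signature whose configure_kernel defaults to nullptr; the earlier hunk in get_kernel guards the call accordingly, because a std::function built from nullptr is empty and must not be invoked. A small standalone example of the pattern:

#include <functional>
#include <iostream>

// A std::function parameter defaulted to nullptr is empty, and an empty
// std::function is falsy, so the call can be guarded with a plain if.
void run(std::function<void(int)> configure = nullptr) {
  if (configure) {
    configure(42);
  }
  std::cout << "launched\n";
}

int main() {
  run();                                       // no callback: guard skips the call
  run([](int x) { std::cout << x << '\n'; });  // callback invoked once
  return 0;
}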
@@ -109,6 +106,6 @@ JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
     const KernelBuilder& builder,
-    bool cache = true);
+    bool use_disk_cache = true);
 
 } // namespace mlx::core::cu
@@ -308,7 +308,7 @@ class CustomKernel : public Primitive {
         shape_infos_(std::move(shape_infos)),
         ensure_row_contiguous_(ensure_row_contiguous),
         init_value_(init_value),
-        scalar_arguments_(scalar_arguments),
+        scalar_arguments_(std::move(scalar_arguments)),
         is_precompiled_(is_precompiled),
         shared_memory_(shared_memory) {}
 
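scalar_arguments is taken by value, so moving it in the initializer list lets a caller that passes an rvalue avoid a copy entirely. A generic illustration of the pattern (the element type here is a placeholder, not CustomKernel's actual argument type):

#include <string>
#include <utility>
#include <vector>

// Take the argument by value and move it into the member: a caller passing
// an rvalue pays for one move instead of a copy.
struct Kernel {
  explicit Kernel(std::vector<std::string> scalar_arguments)
      : scalar_arguments_(std::move(scalar_arguments)) {}
  std::vector<std::string> scalar_arguments_;
};

int main() {
  Kernel k({"alpha", "beta"});  // the temporary vector is moved, not copied
  return 0;
}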