This commit is contained in:
Angelos Katharopoulos 2025-08-20 00:28:28 -07:00
parent fa56bf2feb
commit d6b204b528
4 changed files with 24 additions and 25 deletions

View File

@ -49,7 +49,7 @@ std::string template_arguments_hash(
}
std::string build_kernel(
std::string func_name,
const std::string& func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
@ -316,7 +316,7 @@ void CustomKernel::eval_gpu(
name_,
[&]() {
return std::make_tuple(
is_precompiled_, source_, std::vector<std::string>{kernel_name});
is_precompiled_, source_, std::vector{kernel_name});
},
false);

View File

@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
bool read_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
std::vector<char>* ptx,
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
std::string& ptx,
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
@ -117,15 +117,15 @@ bool read_cached_ptx(
if (!ptx_file.good()) {
return false;
}
ptx->resize(ptx_size);
ptx_file.read(ptx->data(), ptx_size);
ptx.resize(ptx_size);
ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::string line;
while (std::getline(txt_file, line)) {
auto tab = line.find('\t');
if (tab != std::string::npos) {
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
}
}
return true;
@ -135,7 +135,7 @@ bool read_cached_ptx(
void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
@ -222,7 +222,7 @@ void compile(
const std::string& module_name,
const std::string& source,
const std::vector<std::string>& kernel_names,
std::vector<char>& ptx,
std::string& ptx,
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
// Create the program
nvrtcProgram prog;
@ -282,7 +282,7 @@ void compile(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size, 0);
ptx.resize(ptx_size);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
@ -292,7 +292,7 @@ void compile(
void load_module(
const std::string& module_name,
const std::vector<char>& ptx,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
@ -322,18 +322,18 @@ JitModule::JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder,
bool cache) {
bool use_disk_cache) {
// Will hold the actual device executable source code and kernel names
std::vector<char> ptx;
std::string ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
// Try to load them from the file cache
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
auto [precompiled, source_code, kernel_names] = builder();
// Get the PTX or cubin
if (precompiled) {
ptx.insert(ptx.begin(), source_code.begin(), source_code.end());
ptx = std::move(source_code);
for (auto& name : kernel_names) {
ptx_kernels.emplace_back(name, name);
}
@ -342,7 +342,7 @@ JitModule::JitModule(
}
// If requested save them in the file cache for the next launch
if (cache) {
if (use_disk_cache) {
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
@ -368,7 +368,9 @@ CUfunction JitModule::get_kernel(
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
configure_kernel(it->second.first);
if (configure_kernel) {
configure_kernel(it->second.first);
}
it->second.second = true;
}

View File

@ -64,8 +64,8 @@ struct KernelArgs {
private:
std::vector<void*> args_;
// The cuLaunchKernel API requires passing pointers to arguments so store
// temporary values until the kernel is launched.
// The cuGraphAddKernelNode API requires passing pointers to arguments so
// store temporary values until the node is created.
using Arg = std::variant<
std::monostate,
CUdeviceptr,
@ -93,10 +93,7 @@ class JitModule {
JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel);
CUfunction get_kernel(const std::string& kernel_name) {
return get_kernel(kernel_name, [](auto k) {});
}
std::function<void(CUfunction)> configure_kernel = nullptr);
private:
CUmodule module_{nullptr};
@ -109,6 +106,6 @@ JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const KernelBuilder& builder,
bool cache = true);
bool use_disk_cache = true);
} // namespace mlx::core::cu

View File

@ -308,7 +308,7 @@ class CustomKernel : public Primitive {
shape_infos_(std::move(shape_infos)),
ensure_row_contiguous_(ensure_row_contiguous),
init_value_(init_value),
scalar_arguments_(scalar_arguments),
scalar_arguments_(std::move(scalar_arguments)),
is_precompiled_(is_precompiled),
shared_memory_(shared_memory) {}