From d6b204b5282405bcae8a3c3cc2a1dbdc1c4273cb Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 20 Aug 2025 00:28:28 -0700
Subject: [PATCH] comments

---
 mlx/backend/cuda/custom_kernel.cpp |  4 ++--
 mlx/backend/cuda/jit_module.cpp    | 32 ++++++++++++++++--------------
 mlx/backend/cuda/jit_module.h      | 11 ++++------
 mlx/fast_primitives.h              |  2 +-
 4 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp
index 704567ffa..ee1778fd8 100644
--- a/mlx/backend/cuda/custom_kernel.cpp
+++ b/mlx/backend/cuda/custom_kernel.cpp
@@ -49,7 +49,7 @@ std::string template_arguments_hash(
 }
 
 std::string build_kernel(
-    std::string func_name,
+    const std::string& func_name,
     const std::string& header,
     const std::string& source,
     const std::vector<std::string>& input_names,
@@ -316,7 +316,7 @@ void CustomKernel::eval_gpu(
       name_,
       [&]() {
         return std::make_tuple(
-            is_precompiled_, source_, std::vector{kernel_name});
+            is_precompiled_, source_, std::vector<std::string>{kernel_name});
       },
       false);
 
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index b188a7f0e..531052d46 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
 bool read_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
-    std::vector<char>* ptx,
-    std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+    std::string& ptx,
+    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
   if (cache_dir.empty()) {
     return false;
   }
@@ -117,15 +117,15 @@ bool read_cached_ptx(
   if (!ptx_file.good()) {
     return false;
   }
-  ptx->resize(ptx_size);
-  ptx_file.read(ptx->data(), ptx_size);
+  ptx.resize(ptx_size);
+  ptx_file.read(ptx.data(), ptx_size);
 
   std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
   std::string line;
   while (std::getline(txt_file, line)) {
     auto tab = line.find('\t');
     if (tab != std::string::npos) {
-      ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
+      ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
     }
   }
   return true;
@@ -135,7 +135,7 @@ bool read_cached_ptx(
 void write_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
-    const std::vector<char>& ptx,
+    const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     const std::string& source_code) {
   if (cache_dir.empty()) {
@@ -222,7 +222,7 @@ void compile(
     const std::string& module_name,
     const std::string& source,
     const std::vector<std::string>& kernel_names,
-    std::vector<char>& ptx,
+    std::string& ptx,
     std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
   // Create the program
   nvrtcProgram prog;
@@ -282,7 +282,7 @@ void compile(
   } else {
     CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
   }
-  ptx.resize(ptx_size, 0);
+  ptx.resize(ptx_size);
   if (use_sass) {
     CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
   } else {
@@ -292,7 +292,7 @@ void compile(
 
 void load_module(
     const std::string& module_name,
-    const std::vector<char>& ptx,
+    const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     CUmodule& module_,
     std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
@@ -322,18 +322,18 @@ JitModule::JitModule(
     Device& device,
     const std::string& module_name,
     const KernelBuilder& builder,
-    bool cache) {
+    bool use_disk_cache) {
   // Will hold the actual device executable source code and kernel names
-  std::vector<char> ptx;
+  std::string ptx;
   std::vector<std::pair<std::string, std::string>> ptx_kernels;
 
   // Try to load them from the file cache
-  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
     auto [precompiled, source_code, kernel_names] = builder();
 
     // Get the PTX or cubin
     if (precompiled) {
-      ptx.insert(ptx.begin(), source_code.begin(), source_code.end());
+      ptx = std::move(source_code);
       for (auto& name : kernel_names) {
         ptx_kernels.emplace_back(name, name);
       }
@@ -342,7 +342,7 @@ JitModule::JitModule(
   }
 
   // If requested save them in the file cache for the next launch
-  if (cache) {
+  if (use_disk_cache) {
     write_cached_ptx(
         ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
   }
@@ -368,7 +368,9 @@ CUfunction JitModule::get_kernel(
   // If it is the first time we run this kernel then configure it. Do it only
   // once!
   if (!it->second.second) {
-    configure_kernel(it->second.first);
+    if (configure_kernel) {
+      configure_kernel(it->second.first);
+    }
     it->second.second = true;
   }
 
diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h
index a4a75bb96..d919f9bc0 100644
--- a/mlx/backend/cuda/jit_module.h
+++ b/mlx/backend/cuda/jit_module.h
@@ -64,8 +64,8 @@ struct KernelArgs {
  private:
   std::vector<void*> args_;
 
-  // The cuLaunchKernel API requires passing pointers to arguments so store
-  // temporary values until the kernel is launched.
+  // The cuGraphAddKernelNode API requires passing pointers to arguments so
+  // store temporary values until the node is created.
   using Arg = std::variant<
       std::monostate,
       CUdeviceptr,
@@ -93,10 +93,7 @@ class JitModule {
   JitModule& operator=(const JitModule&) = delete;
   CUfunction get_kernel(
       const std::string& kernel_name,
-      std::function<void(CUfunction)> configure_kernel);
-  CUfunction get_kernel(const std::string& kernel_name) {
-    return get_kernel(kernel_name, [](auto k) {});
-  }
+      std::function<void(CUfunction)> configure_kernel = nullptr);
 
  private:
   CUmodule module_{nullptr};
@@ -109,6 +106,6 @@ JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
     const KernelBuilder& builder,
-    bool cache = true);
+    bool use_disk_cache = true);
 
 } // namespace mlx::core::cu
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 98730104d..e0e83f726 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -308,7 +308,7 @@ class CustomKernel : public Primitive {
         shape_infos_(std::move(shape_infos)),
         ensure_row_contiguous_(ensure_row_contiguous),
         init_value_(init_value),
-        scalar_arguments_(scalar_arguments),
+        scalar_arguments_(std::move(scalar_arguments)),
         is_precompiled_(is_precompiled),
         shared_memory_(shared_memory) {}