mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 07:58:41 +08:00
comments
This commit is contained in:
parent
fa56bf2feb
commit
d6b204b528
@ -49,7 +49,7 @@ std::string template_arguments_hash(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string build_kernel(
|
std::string build_kernel(
|
||||||
std::string func_name,
|
const std::string& func_name,
|
||||||
const std::string& header,
|
const std::string& header,
|
||||||
const std::string& source,
|
const std::string& source,
|
||||||
const std::vector<std::string>& input_names,
|
const std::vector<std::string>& input_names,
|
||||||
@ -316,7 +316,7 @@ void CustomKernel::eval_gpu(
|
|||||||
name_,
|
name_,
|
||||||
[&]() {
|
[&]() {
|
||||||
return std::make_tuple(
|
return std::make_tuple(
|
||||||
is_precompiled_, source_, std::vector<std::string>{kernel_name});
|
is_precompiled_, source_, std::vector{kernel_name});
|
||||||
},
|
},
|
||||||
false);
|
false);
|
||||||
|
|
||||||
|
@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
|
|||||||
bool read_cached_ptx(
|
bool read_cached_ptx(
|
||||||
const std::filesystem::path& cache_dir,
|
const std::filesystem::path& cache_dir,
|
||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
std::vector<char>* ptx,
|
std::string& ptx,
|
||||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||||
if (cache_dir.empty()) {
|
if (cache_dir.empty()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -117,15 +117,15 @@ bool read_cached_ptx(
|
|||||||
if (!ptx_file.good()) {
|
if (!ptx_file.good()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ptx->resize(ptx_size);
|
ptx.resize(ptx_size);
|
||||||
ptx_file.read(ptx->data(), ptx_size);
|
ptx_file.read(ptx.data(), ptx_size);
|
||||||
|
|
||||||
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||||
std::string line;
|
std::string line;
|
||||||
while (std::getline(txt_file, line)) {
|
while (std::getline(txt_file, line)) {
|
||||||
auto tab = line.find('\t');
|
auto tab = line.find('\t');
|
||||||
if (tab != std::string::npos) {
|
if (tab != std::string::npos) {
|
||||||
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@ -135,7 +135,7 @@ bool read_cached_ptx(
|
|||||||
void write_cached_ptx(
|
void write_cached_ptx(
|
||||||
const std::filesystem::path& cache_dir,
|
const std::filesystem::path& cache_dir,
|
||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const std::vector<char>& ptx,
|
const std::string& ptx,
|
||||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||||
const std::string& source_code) {
|
const std::string& source_code) {
|
||||||
if (cache_dir.empty()) {
|
if (cache_dir.empty()) {
|
||||||
@ -222,7 +222,7 @@ void compile(
|
|||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const std::string& source,
|
const std::string& source,
|
||||||
const std::vector<std::string>& kernel_names,
|
const std::vector<std::string>& kernel_names,
|
||||||
std::vector<char>& ptx,
|
std::string& ptx,
|
||||||
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||||
// Create the program
|
// Create the program
|
||||||
nvrtcProgram prog;
|
nvrtcProgram prog;
|
||||||
@ -282,7 +282,7 @@ void compile(
|
|||||||
} else {
|
} else {
|
||||||
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
||||||
}
|
}
|
||||||
ptx.resize(ptx_size, 0);
|
ptx.resize(ptx_size);
|
||||||
if (use_sass) {
|
if (use_sass) {
|
||||||
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
||||||
} else {
|
} else {
|
||||||
@ -292,7 +292,7 @@ void compile(
|
|||||||
|
|
||||||
void load_module(
|
void load_module(
|
||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const std::vector<char>& ptx,
|
const std::string& ptx,
|
||||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||||
CUmodule& module_,
|
CUmodule& module_,
|
||||||
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
|
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
|
||||||
@ -322,18 +322,18 @@ JitModule::JitModule(
|
|||||||
Device& device,
|
Device& device,
|
||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const KernelBuilder& builder,
|
const KernelBuilder& builder,
|
||||||
bool cache) {
|
bool use_disk_cache) {
|
||||||
// Will hold the actual device executable source code and kernel names
|
// Will hold the actual device executable source code and kernel names
|
||||||
std::vector<char> ptx;
|
std::string ptx;
|
||||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||||
|
|
||||||
// Try to load them from the file cache
|
// Try to load them from the file cache
|
||||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
|
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
|
||||||
auto [precompiled, source_code, kernel_names] = builder();
|
auto [precompiled, source_code, kernel_names] = builder();
|
||||||
|
|
||||||
// Get the PTX or cubin
|
// Get the PTX or cubin
|
||||||
if (precompiled) {
|
if (precompiled) {
|
||||||
ptx.insert(ptx.begin(), source_code.begin(), source_code.end());
|
ptx = std::move(source_code);
|
||||||
for (auto& name : kernel_names) {
|
for (auto& name : kernel_names) {
|
||||||
ptx_kernels.emplace_back(name, name);
|
ptx_kernels.emplace_back(name, name);
|
||||||
}
|
}
|
||||||
@ -342,7 +342,7 @@ JitModule::JitModule(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If requested save them in the file cache for the next launch
|
// If requested save them in the file cache for the next launch
|
||||||
if (cache) {
|
if (use_disk_cache) {
|
||||||
write_cached_ptx(
|
write_cached_ptx(
|
||||||
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||||
}
|
}
|
||||||
@ -368,7 +368,9 @@ CUfunction JitModule::get_kernel(
|
|||||||
// If it is the first time we run this kernel then configure it. Do it only
|
// If it is the first time we run this kernel then configure it. Do it only
|
||||||
// once!
|
// once!
|
||||||
if (!it->second.second) {
|
if (!it->second.second) {
|
||||||
configure_kernel(it->second.first);
|
if (configure_kernel) {
|
||||||
|
configure_kernel(it->second.first);
|
||||||
|
}
|
||||||
it->second.second = true;
|
it->second.second = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,8 +64,8 @@ struct KernelArgs {
|
|||||||
private:
|
private:
|
||||||
std::vector<void*> args_;
|
std::vector<void*> args_;
|
||||||
|
|
||||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
// The cuGraphAddKernelNode API requires passing pointers to arguments so
|
||||||
// temporary values until the kernel is launched.
|
// store temporary values until the node is created.
|
||||||
using Arg = std::variant<
|
using Arg = std::variant<
|
||||||
std::monostate,
|
std::monostate,
|
||||||
CUdeviceptr,
|
CUdeviceptr,
|
||||||
@ -93,10 +93,7 @@ class JitModule {
|
|||||||
JitModule& operator=(const JitModule&) = delete;
|
JitModule& operator=(const JitModule&) = delete;
|
||||||
CUfunction get_kernel(
|
CUfunction get_kernel(
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
std::function<void(CUfunction)> configure_kernel);
|
std::function<void(CUfunction)> configure_kernel = nullptr);
|
||||||
CUfunction get_kernel(const std::string& kernel_name) {
|
|
||||||
return get_kernel(kernel_name, [](auto k) {});
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CUmodule module_{nullptr};
|
CUmodule module_{nullptr};
|
||||||
@ -109,6 +106,6 @@ JitModule& get_jit_module(
|
|||||||
const mlx::core::Device& device,
|
const mlx::core::Device& device,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const KernelBuilder& builder,
|
const KernelBuilder& builder,
|
||||||
bool cache = true);
|
bool use_disk_cache = true);
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
@ -308,7 +308,7 @@ class CustomKernel : public Primitive {
|
|||||||
shape_infos_(std::move(shape_infos)),
|
shape_infos_(std::move(shape_infos)),
|
||||||
ensure_row_contiguous_(ensure_row_contiguous),
|
ensure_row_contiguous_(ensure_row_contiguous),
|
||||||
init_value_(init_value),
|
init_value_(init_value),
|
||||||
scalar_arguments_(scalar_arguments),
|
scalar_arguments_(std::move(scalar_arguments)),
|
||||||
is_precompiled_(is_precompiled),
|
is_precompiled_(is_precompiled),
|
||||||
shared_memory_(shared_memory) {}
|
shared_memory_(shared_memory) {}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user