mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 14:59:22 +08:00
tmp
This commit is contained in:
parent
3b94e37270
commit
055c1ca929
@ -370,13 +370,13 @@ void CustomKernel::eval_gpu(
|
||||
for (const auto& t : copies) {
|
||||
encoder.add_temporary(t);
|
||||
}
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
if (shared_memory_ > 0 && shared_memory_ > 48000) {
|
||||
cuFuncSetAttribute(
|
||||
kernel,
|
||||
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_memory_);
|
||||
}
|
||||
auto kernel =
|
||||
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
|
||||
if (smem > 0 && smem > 48000) {
|
||||
cuFuncSetAttribute(
|
||||
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
|
||||
}
|
||||
});
|
||||
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
|
||||
}
|
||||
|
||||
|
@ -312,7 +312,7 @@ JitModule::JitModule(
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
CUfunction kernel;
|
||||
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||
kernels_[name] = kernel;
|
||||
kernels_[name] = std::make_pair(kernel, false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -327,7 +327,7 @@ JitModule::JitModule(
|
||||
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
|
||||
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
|
||||
CUresult jit_result = cuModuleLoadDataEx(
|
||||
&module_, ptx.c_str(), std::size(options), options, values);
|
||||
&module_, ptx.data(), std::size(options), options, values);
|
||||
if (jit_result != CUDA_SUCCESS) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to load compiled {} kernel: {}.", module_name, jit_log));
|
||||
@ -337,7 +337,7 @@ JitModule::JitModule(
|
||||
for (const auto& name : kernel_names) {
|
||||
CUfunction kernel;
|
||||
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
|
||||
kernels_[name] = kernel;
|
||||
kernels_[name] = std::make_pair(kernel, false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -345,13 +345,23 @@ JitModule::~JitModule() {
|
||||
CHECK_CUDA_ERROR(cuModuleUnload(module_));
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
||||
CUfunction JitModule::get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel) {
|
||||
auto it = kernels_.find(kernel_name);
|
||||
if (it == kernels_.end()) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("There is no kernel named {}.", kernel_name));
|
||||
}
|
||||
return it->second;
|
||||
|
||||
// If it is the first time we run this kernel then configure it. Do it only
|
||||
// once!
|
||||
if (!it->second.second) {
|
||||
configure_kernel(it->second.first);
|
||||
it->second.second = true;
|
||||
}
|
||||
|
||||
return it->second.first;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
||||
|
@ -94,11 +94,16 @@ class JitModule {
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
CUfunction get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel);
|
||||
CUfunction get_kernel(const std::string& kernel_name) {
|
||||
return get_kernel(kernel_name, [](auto k) {});
|
||||
}
|
||||
|
||||
private:
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||
|
@ -388,7 +388,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"precompiled_custom_kernel",
|
||||
[](const std::string& name,
|
||||
const std::string& compiled_source,
|
||||
const nb::bytes compiled_source,
|
||||
const std::vector<ScalarOrArray>& inputs_,
|
||||
const std::vector<mx::Shape>& output_shapes,
|
||||
const std::vector<mx::Dtype>& output_dtypes,
|
||||
@ -429,7 +429,9 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
return mx::fast::precompiled_custom_kernel(
|
||||
name,
|
||||
compiled_source,
|
||||
std::string(
|
||||
static_cast<const char*>(compiled_source.data()),
|
||||
compiled_source.size()),
|
||||
inputs,
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
|
Loading…
Reference in New Issue
Block a user