mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 18:26:41 +08:00
tmp
This commit is contained in:
parent
3b94e37270
commit
055c1ca929
@ -370,13 +370,13 @@ void CustomKernel::eval_gpu(
|
|||||||
for (const auto& t : copies) {
|
for (const auto& t : copies) {
|
||||||
encoder.add_temporary(t);
|
encoder.add_temporary(t);
|
||||||
}
|
}
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel =
|
||||||
if (shared_memory_ > 0 && shared_memory_ > 48000) {
|
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
|
||||||
cuFuncSetAttribute(
|
if (smem > 0 && smem > 48000) {
|
||||||
kernel,
|
cuFuncSetAttribute(
|
||||||
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
|
||||||
shared_memory_);
|
}
|
||||||
}
|
});
|
||||||
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
|
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -312,7 +312,7 @@ JitModule::JitModule(
|
|||||||
for (const auto& [name, mangled] : ptx_kernels) {
|
for (const auto& [name, mangled] : ptx_kernels) {
|
||||||
CUfunction kernel;
|
CUfunction kernel;
|
||||||
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||||
kernels_[name] = kernel;
|
kernels_[name] = std::make_pair(kernel, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -327,7 +327,7 @@ JitModule::JitModule(
|
|||||||
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
|
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
|
||||||
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
|
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
|
||||||
CUresult jit_result = cuModuleLoadDataEx(
|
CUresult jit_result = cuModuleLoadDataEx(
|
||||||
&module_, ptx.c_str(), std::size(options), options, values);
|
&module_, ptx.data(), std::size(options), options, values);
|
||||||
if (jit_result != CUDA_SUCCESS) {
|
if (jit_result != CUDA_SUCCESS) {
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
"Failed to load compiled {} kernel: {}.", module_name, jit_log));
|
"Failed to load compiled {} kernel: {}.", module_name, jit_log));
|
||||||
@ -337,7 +337,7 @@ JitModule::JitModule(
|
|||||||
for (const auto& name : kernel_names) {
|
for (const auto& name : kernel_names) {
|
||||||
CUfunction kernel;
|
CUfunction kernel;
|
||||||
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
|
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
|
||||||
kernels_[name] = kernel;
|
kernels_[name] = std::make_pair(kernel, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -345,13 +345,23 @@ JitModule::~JitModule() {
|
|||||||
CHECK_CUDA_ERROR(cuModuleUnload(module_));
|
CHECK_CUDA_ERROR(cuModuleUnload(module_));
|
||||||
}
|
}
|
||||||
|
|
||||||
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
CUfunction JitModule::get_kernel(
|
||||||
|
const std::string& kernel_name,
|
||||||
|
std::function<void(CUfunction)> configure_kernel) {
|
||||||
auto it = kernels_.find(kernel_name);
|
auto it = kernels_.find(kernel_name);
|
||||||
if (it == kernels_.end()) {
|
if (it == kernels_.end()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
fmt::format("There is no kernel named {}.", kernel_name));
|
fmt::format("There is no kernel named {}.", kernel_name));
|
||||||
}
|
}
|
||||||
return it->second;
|
|
||||||
|
// If it is the first time we run this kernel then configure it. Do it only
|
||||||
|
// once!
|
||||||
|
if (!it->second.second) {
|
||||||
|
configure_kernel(it->second.first);
|
||||||
|
it->second.second = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return it->second.first;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
||||||
|
@ -94,11 +94,16 @@ class JitModule {
|
|||||||
|
|
||||||
JitModule(const JitModule&) = delete;
|
JitModule(const JitModule&) = delete;
|
||||||
JitModule& operator=(const JitModule&) = delete;
|
JitModule& operator=(const JitModule&) = delete;
|
||||||
CUfunction get_kernel(const std::string& kernel_name);
|
CUfunction get_kernel(
|
||||||
|
const std::string& kernel_name,
|
||||||
|
std::function<void(CUfunction)> configure_kernel);
|
||||||
|
CUfunction get_kernel(const std::string& kernel_name) {
|
||||||
|
return get_kernel(kernel_name, [](auto k) {});
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CUmodule module_{nullptr};
|
CUmodule module_{nullptr};
|
||||||
std::unordered_map<std::string, CUfunction> kernels_;
|
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||||
|
@ -388,7 +388,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
m.def(
|
m.def(
|
||||||
"precompiled_custom_kernel",
|
"precompiled_custom_kernel",
|
||||||
[](const std::string& name,
|
[](const std::string& name,
|
||||||
const std::string& compiled_source,
|
const nb::bytes compiled_source,
|
||||||
const std::vector<ScalarOrArray>& inputs_,
|
const std::vector<ScalarOrArray>& inputs_,
|
||||||
const std::vector<mx::Shape>& output_shapes,
|
const std::vector<mx::Shape>& output_shapes,
|
||||||
const std::vector<mx::Dtype>& output_dtypes,
|
const std::vector<mx::Dtype>& output_dtypes,
|
||||||
@ -429,7 +429,9 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
return mx::fast::precompiled_custom_kernel(
|
return mx::fast::precompiled_custom_kernel(
|
||||||
name,
|
name,
|
||||||
compiled_source,
|
std::string(
|
||||||
|
static_cast<const char*>(compiled_source.data()),
|
||||||
|
compiled_source.size()),
|
||||||
inputs,
|
inputs,
|
||||||
output_shapes,
|
output_shapes,
|
||||||
output_dtypes,
|
output_dtypes,
|
||||||
|
Loading…
Reference in New Issue
Block a user