mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom cuda kernel (#2517)
This commit is contained in:
committed by
GitHub
parent
f4c8888cbe
commit
e397177f6e
@@ -19,7 +19,8 @@ namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
|
||||
using KernelBuilderResult = std::pair<
|
||||
using KernelBuilderResult = std::tuple<
|
||||
/* precompiled */ bool,
|
||||
/* source code */ std::string,
|
||||
/* kernel names */ std::vector<std::string>>;
|
||||
using KernelBuilder = std::function<KernelBuilderResult()>;
|
||||
@@ -63,14 +64,16 @@ struct KernelArgs {
|
||||
private:
|
||||
std::vector<void*> args_;
|
||||
|
||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||
// temporary values untill kernel is launched.
|
||||
// The cuGraphAddKernelNode API requires passing pointers to arguments so
|
||||
// store temporary values until the node is created.
|
||||
using Arg = std::variant<
|
||||
std::monostate,
|
||||
CUdeviceptr,
|
||||
bool,
|
||||
int32_t,
|
||||
uint32_t,
|
||||
int64_t,
|
||||
float,
|
||||
SmallVector<const void*>,
|
||||
SmallVector<int32_t>,
|
||||
SmallVector<int64_t>>;
|
||||
@@ -82,16 +85,19 @@ class JitModule {
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder);
|
||||
const KernelBuilder& builder,
|
||||
bool cache);
|
||||
~JitModule();
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
CUfunction get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel = nullptr);
|
||||
|
||||
private:
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||
@@ -99,6 +105,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder);
|
||||
const KernelBuilder& builder,
|
||||
bool use_disk_cache = true);
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
Reference in New Issue
Block a user