Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-28 00:36:32 +08:00)

Commit: d2ae81b413 "A bit of refactoring"
Parent: 3938aaaf24
@@ -267,7 +267,8 @@ void Compiled::eval_gpu(
      }
    }
    return std::make_pair(std::move(builder.os), std::move(kernel_names));
    return std::make_tuple(
        false, std::move(builder.os), std::move(kernel_names));
  });

  // Collapse contiguous dims to route to a faster kernel if possible. Also
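Every kernel-source builder touched by this commit changes shape the same way: instead of a (source, kernel_names) pair it now returns a three-element tuple whose leading bool tells the JIT machinery whether the payload is already a compiled PTX/cubin image. The in-tree builders, like the one above, all pass false. A minimal sketch of the new return shape (variable names are illustrative):

    // KernelBuilderResult is now std::tuple<bool, std::string, std::vector<std::string>>;
    // the first element marks a precompiled payload, see the jit_module.h hunk below.
    return std::make_tuple(false, std::move(source), std::move(kernel_names));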
@@ -142,17 +142,17 @@ std::string build_kernel(

} // namespace

MetalKernelFunction metal_kernel(
CustomKernelFunction cuda_kernel(
    const std::string& name,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::string& source,
    const std::string& header /* = "" */,
    bool ensure_row_contiguous /* = true */,
    bool atomic_outputs /* = false */) {
    const std::string& header,
    bool ensure_row_contiguous,
    int shared_memory) {
  if (output_names.empty()) {
    throw std::invalid_argument(
        "[metal_kernel] Must specify at least one output.");
        "[custom_kernel] Must specify at least one output.");
  }

  std::vector<CustomKernelShapeInfo> shape_infos;
@@ -177,21 +177,21 @@ MetalKernelFunction metal_kernel(
             StreamOrDevice s_ = {}) {
    if (inputs.size() != input_names.size()) {
      std::ostringstream msg;
      msg << "[metal_kernel] Expected `inputs` to have size "
      msg << "[custom_kernel] Expected `inputs` to have size "
          << input_names.size() << " but got size " << inputs.size() << "."
          << std::endl;
      throw std::invalid_argument(msg.str());
    }
    if (output_shapes.size() != output_names.size()) {
      std::ostringstream msg;
      msg << "[metal_kernel] Expected `output_shapes` to have size "
      msg << "[custom_kernel] Expected `output_shapes` to have size "
          << output_names.size() << " but got size " << output_shapes.size()
          << "." << std::endl;
      throw std::invalid_argument(msg.str());
    }
    if (output_dtypes.size() != output_names.size()) {
      std::ostringstream msg;
      msg << "[metal_kernel] Expected `output_dtypes` to have size "
      msg << "[custom_kernel] Expected `output_dtypes` to have size "
          << output_names.size() << " but got size " << output_dtypes.size()
          << "." << std::endl;
      throw std::invalid_argument(msg.str());
@@ -199,7 +199,7 @@ MetalKernelFunction metal_kernel(

    auto s = to_stream(s_);
    if (s.device != Device::gpu) {
      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
      throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
    }

    std::string kernel_name =
@@ -237,12 +237,12 @@ MetalKernelFunction metal_kernel(
            init_value,
            std::vector<ScalarArg>{},
            false,
            0),
            shared_memory),
        std::move(inputs));
  };
}

std::vector<array> precompiled_custom_kernel(
std::vector<array> precompiled_cuda_kernel(
    const std::string& name,
    const std::string& compiled_source,
    const std::vector<array>& inputs,
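These hunks give the CUDA backend its own fast::cuda_kernel entry point; relative to metal_kernel, the Metal-only atomic_outputs flag is dropped and a shared_memory argument (the amount of dynamic shared memory to request, 0 for none) is added. A hedged sketch of a call with the new signature; the kernel name and body below are made up for illustration:

    // Illustrative only: the trailing int replaces metal_kernel's atomic_outputs.
    auto kernel = mlx::core::fast::cuda_kernel(
        "myexp",                                       // kernel name
        {"inp"},                                       // input_names
        {"out"},                                       // output_names
        "auto elem = cooperative_groups::this_grid().thread_rank();\n"
        "out[elem] = exp(inp[elem]);",                 // body, the signature is generated
        /* header */ "",
        /* ensure_row_contiguous */ true,
        /* shared_memory */ 0);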
@@ -312,17 +312,14 @@ void CustomKernel::eval_gpu(
  // Compile the custom kernel
  std::string kernel_name =
      (is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
  auto get_module = [&]() -> cu::JitModule& {
    if (is_precompiled_) {
      return cu::get_jit_module(
          s.device, name_, source_, std::vector<std::string>{kernel_name});
    } else {
      return cu::get_jit_module(s.device, name_, [&]() {
        return std::make_pair(source_, std::vector<std::string>{kernel_name});
      });
    }
  };
  cu::JitModule& mod = get_module();
  cu::JitModule& mod = cu::get_jit_module(
      s.device,
      name_,
      [&]() {
        return std::make_tuple(
            is_precompiled_, source_, std::vector<std::string>{kernel_name});
      },
      false);

  // Make the arguments
  cu::KernelArgs args;
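The trailing `false` in the call above is the new `cache` argument of `cu::get_jit_module`: user-supplied kernels are not written to the on-disk PTX cache, while the built-in ops keep the default of `true` (see the header change further down). Spelled out with an argument comment, and with `builder` standing for the lambda shown above, the call is equivalent to this sketch:

    // Same call as above, with the new parameter named for clarity.
    cu::JitModule& mod = cu::get_jit_module(s.device, name_, builder, /* cache */ false);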
@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
            large ? "int64_t" : "int32_t"));
      }
    }
    return std::make_pair(jit_source_gather, std::move(kernel_names));
    return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
  });

  cu::KernelArgs args;
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
            large ? "int64_t" : "int32_t"));
      }
    }
    return std::make_pair(jit_source_scatter, std::move(kernel_names));
    return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
  });

  cu::KernelArgs args;
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
        }
      }
    }
    return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
    return std::make_tuple(
        false, jit_source_gather_axis, std::move(kernel_names));
  });

  size_t idx_size_pre = 1;
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
        }
      }
    }
    return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
    return std::make_tuple(
        false, jit_source_scatter_axis, std::move(kernel_names));
  });

  size_t idx_size_pre = 1;
@@ -217,85 +217,85 @@ constexpr const char* g_headers[] = {
    jit_source_utils,
};

} // namespace

JitModule::JitModule(
void compile(
    Device& device,
    const std::string& module_name,
    const KernelBuilder& builder) {
  // Check cache.
  std::vector<char> ptx;
  std::vector<std::pair<std::string, std::string>> ptx_kernels;
  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
    // Create program.
    auto [source_code, kernel_names] = builder();
    nvrtcProgram prog;
    CHECK_NVRTC_ERROR(nvrtcCreateProgram(
        &prog,
        source_code.c_str(),
        (module_name + ".cu").c_str(),
        std::size(g_headers),
        g_headers,
        g_include_names));
    std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
        &prog,
        [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
    for (const auto& name : kernel_names) {
      CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
    }

    // Compile program.
    std::vector<const char*> args;
    bool use_sass = compiler_supports_device_sass(device);
    std::string compute = fmt::format(
        "--gpu-architecture={}_{}{}",
        use_sass ? "sm" : "compute",
        device.compute_capability_major(),
        device.compute_capability_minor());
    args.push_back(compute.c_str());
    std::string cccl_include = cccl_dir();
    if (!cccl_include.empty()) {
      cccl_include = fmt::format("--include-path={}", cccl_include);
      args.push_back(cccl_include.c_str());
    }
    std::string cuda_include =
        fmt::format("--include-path={}/include", cuda_home());
    args.push_back(cuda_include.c_str());
    nvrtcResult compile_result =
        nvrtcCompileProgram(prog, args.size(), args.data());
    if (compile_result != NVRTC_SUCCESS) {
      size_t log_size;
      CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
      std::vector<char> log(log_size + 1, 0);
      CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
      throw std::runtime_error(
          fmt::format("Failed to compile kernel: {}.", log.data()));
    }

    // Get mangled names of kernel names.
    for (const auto& name : kernel_names) {
      const char* mangled;
      CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
      ptx_kernels.emplace_back(name, mangled);
    }

    // Get ptx data.
    size_t ptx_size;
    if (use_sass) {
      CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
    } else {
      CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
    }
    ptx.resize(ptx_size, 0);
    if (use_sass) {
      CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
    } else {
      CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
    }
    write_cached_ptx(
        ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
    const std::string& source,
    const std::vector<std::string>& kernel_names,
    std::vector<char>& ptx,
    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
  // Create the program
  nvrtcProgram prog;
  CHECK_NVRTC_ERROR(nvrtcCreateProgram(
      &prog,
      source.c_str(),
      (module_name + ".cu").c_str(),
      std::size(g_headers),
      g_headers,
      g_include_names));
  std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
      &prog,
      [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
  for (const auto& name : kernel_names) {
    CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
  }

  // Compile program.
  std::vector<const char*> args;
  bool use_sass = compiler_supports_device_sass(device);
  std::string compute = fmt::format(
      "--gpu-architecture={}_{}{}",
      use_sass ? "sm" : "compute",
      device.compute_capability_major(),
      device.compute_capability_minor());
  args.push_back(compute.c_str());
  std::string cccl_include = cccl_dir();
  if (!cccl_include.empty()) {
    cccl_include = fmt::format("--include-path={}", cccl_include);
    args.push_back(cccl_include.c_str());
  }
  std::string cuda_include =
      fmt::format("--include-path={}/include", cuda_home());
  args.push_back(cuda_include.c_str());
  nvrtcResult compile_result =
      nvrtcCompileProgram(prog, args.size(), args.data());
  if (compile_result != NVRTC_SUCCESS) {
    size_t log_size;
    CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
    std::vector<char> log(log_size + 1, 0);
    CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
    throw std::runtime_error(
        fmt::format("Failed to compile kernel: {}.", log.data()));
  }

  // Get mangled names of kernel names.
  for (const auto& name : kernel_names) {
    const char* mangled;
    CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
    ptx_kernels.emplace_back(name, mangled);
  }

  // Get ptx data.
  size_t ptx_size;
  if (use_sass) {
    CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
  } else {
    CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
  }
  ptx.resize(ptx_size, 0);
  if (use_sass) {
    CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
  } else {
    CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
  }
}

void load_module(
    const std::string& module_name,
    const std::vector<char>& ptx,
    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
    CUmodule& module_,
    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
  // Load module.
  char jit_log[4089] = {};
  CUjit_option options[] = {
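The hunk above splits what used to be one JitModule constructor into two file-local helpers: compile() runs NVRTC and produces the PTX or cubin plus the kernel-name mapping, and load_module() turns that image into a CUmodule and the CUfunction handles. The next hunk shows the constructor composing them, roughly as in this outline:

    // Rough outline of the new flow (the real code is in the following hunk).
    std::vector<char> ptx;
    std::vector<std::pair<std::string, std::string>> ptx_kernels;
    compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
    load_module(module_name, ptx, ptx_kernels, module_, kernels_);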
@@ -312,33 +312,44 @@ JitModule::JitModule(
  for (const auto& [name, mangled] : ptx_kernels) {
    CUfunction kernel;
    CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
    kernels_[name] = std::make_pair(kernel, false);
    kernels[name] = std::make_pair(kernel, false);
  }
}

} // namespace

JitModule::JitModule(
    Device& device,
    const std::string& module_name,
    const std::string& ptx,
    const std::vector<std::string>& kernel_names) {
  // Load module.
  char jit_log[4089] = {};
  CUjit_option options[] = {
      CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
  void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
  CUresult jit_result = cuModuleLoadDataEx(
      &module_, ptx.data(), std::size(options), options, values);
  if (jit_result != CUDA_SUCCESS) {
    throw std::runtime_error(fmt::format(
        "Failed to load compiled {} kernel: {}.", module_name, jit_log));
    const KernelBuilder& builder,
    bool cache) {
  // Will hold the actual device executable source code and kernel names
  std::vector<char> ptx;
  std::vector<std::pair<std::string, std::string>> ptx_kernels;

  // Try to load them from the file cache
  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
    auto [precompiled, source_code, kernel_names] = builder();

    // Get the PTX or cubin
    if (precompiled) {
      ptx.insert(ptx.begin(), source_code.begin(), source_code.end());
      for (auto& name : kernel_names) {
        ptx_kernels.emplace_back(name, name);
      }
    } else {
      compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
    }

    // If requested save them in the file cache for the next launch
    if (cache) {
      write_cached_ptx(
          ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
    }
  }

  // Load kernels.
  for (const auto& name : kernel_names) {
    CUfunction kernel;
    CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
    kernels_[name] = std::make_pair(kernel, false);
  }
  // Load the module
  load_module(module_name, ptx, ptx_kernels, module_, kernels_);
}

JitModule::~JitModule() {
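With this change a single JitModule(device, name, builder, cache) constructor serves both paths, and the builder decides which one is taken. A hedged sketch of the two builder flavors; the identifiers and payloads here are illustrative, with `source` assumed to hold CUDA C++ and `image` assumed to hold PTX or cubin bytes:

    // precompiled = false: `source` is CUDA C++, compile() is invoked and the
    // resulting PTX may be written to the file cache when cache == true.
    cu::KernelBuilder jit_builder = [&]() {
      return std::make_tuple(false, source, std::vector<std::string>{"my_kernel"});
    };

    // precompiled = true: the bytes go straight to load_module() and the kernel
    // names are used as-is (no NVRTC name mangling is involved).
    cu::KernelBuilder precompiled_builder = [&]() {
      return std::make_tuple(true, image, std::vector<std::string>{"my_kernel"});
    };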
@@ -372,25 +383,12 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const KernelBuilder& builder) {
    const KernelBuilder& builder,
    bool cache) {
  auto& map = get_jit_module_cache();
  auto it = map.find(name);
  if (it == map.end()) {
    it = map.try_emplace(name, cu::device(device), name, builder).first;
  }
  return it->second;
}

JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const std::string& ptx,
    const std::vector<std::string>& kernel_names) {
  auto& map = get_jit_module_cache();
  auto it = map.find(name);
  if (it == map.end()) {
    it = map.try_emplace(name, cu::device(device), name, ptx, kernel_names)
             .first;
    it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
  }
  return it->second;
}
@@ -19,7 +19,8 @@ namespace mlx::core::cu {

class Device;

using KernelBuilderResult = std::pair<
using KernelBuilderResult = std::tuple<
    /* precompiled */ bool,
    /* source code */ std::string,
    /* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>;
@@ -84,12 +85,8 @@ class JitModule {
  JitModule(
      Device& device,
      const std::string& module_name,
      const KernelBuilder& builder);
  JitModule(
      Device& device,
      const std::string& module_name,
      const std::string& ptx,
      const std::vector<std::string>& kernel_names);
      const KernelBuilder& builder,
      bool cache);
  ~JitModule();

  JitModule(const JitModule&) = delete;
@@ -111,12 +108,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const KernelBuilder& builder);

JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const std::string& ptx,
    const std::vector<std::string>& kernel_names);
    const KernelBuilder& builder,
    bool cache = true);

} // namespace mlx::core::cu
@@ -1,11 +1,47 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cuda.h"
#include "mlx/fast.h"

namespace mlx::core::cu {
namespace mlx::core {

namespace cu {

bool is_available() {
  return false;
}

} // namespace mlx::core::cu
} // namespace cu

namespace fast {

CustomKernelFunction cuda_kernel(
    const std::string&,
    const std::vector<std::string>&,
    const std::vector<std::string>&,
    const std::string&,
    const std::string&,
    bool,
    int) {
  throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}

std::vector<array> precompiled_cuda_kernel(
    const std::string&,
    const std::string&,
    const std::vector<array>&,
    const std::vector<Shape>&,
    const std::vector<Dtype>&,
    const std::vector<ScalarArg>&,
    std::tuple<int, int, int>,
    std::tuple<int, int, int>,
    int shared_memory,
    std::optional<float> init_value,
    bool ensure_row_contiguous,
    StreamOrDevice) {
  throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}

} // namespace fast

} // namespace mlx::core
@@ -24,4 +24,19 @@ device_info() {

} // namespace metal

namespace fast {

CustomKernelFunction metal_kernel(
    const std::string&,
    const std::vector<std::string>&,
    const std::vector<std::string>&,
    const std::string&,
    const std::string&,
    bool,
    bool) {
  throw std::runtime_error("[metal_kernel] No Metal back-end.");
}

} // namespace fast

} // namespace mlx::core
mlx/fast.h
@@ -69,7 +69,7 @@ array affine_dequantize(
using TemplateArg = std::variant<int, bool, Dtype>;
using ScalarArg = std::variant<bool, int, float>;

using MetalKernelFunction = std::function<std::vector<array>(
using CustomKernelFunction = std::function<std::vector<array>(
    const std::vector<array>&,
    const std::vector<Shape>&,
    const std::vector<Dtype>&,
@@ -80,7 +80,7 @@ using MetalKernelFunction = std::function<std::vector<array>(
    bool,
    StreamOrDevice)>;

MetalKernelFunction metal_kernel(
CustomKernelFunction metal_kernel(
    const std::string& name,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
@@ -89,7 +89,16 @@ MetalKernelFunction metal_kernel(
    bool ensure_row_contiguous = true,
    bool atomic_outputs = false);

std::vector<array> precompiled_custom_kernel(
CustomKernelFunction cuda_kernel(
    const std::string& name,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::string& source,
    const std::string& header = "",
    bool ensure_row_contiguous = true,
    int shared_memory = 0);

std::vector<array> precompiled_cuda_kernel(
    const std::string& name,
    const std::string& compiled_source,
    const std::vector<array>& inputs,
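The header now exposes one callable type, CustomKernelFunction, shared by metal_kernel and the new cuda_kernel, plus the precompiled_cuda_kernel free function. Only some of the std::function parameters are visible in this hunk; going by the Python __call__ signature bound further below, the full order is inputs, output_shapes, output_dtypes, grid, threadgroup, template arguments, init_value, verbose and stream. A hedged C++ sketch of invoking a kernel obtained from cuda_kernel, assuming `using namespace mlx::core` and an existing array `a`:

    // Illustrative call through the CustomKernelFunction alias.
    auto outputs = kernel(
        {a},                                   // inputs
        {a.shape()},                           // output_shapes
        {a.dtype()},                           // output_dtypes
        {static_cast<int>(a.size()), 1, 1},    // grid, counted in threads
        {256, 1, 1},                           // threadgroup
        {{"T", float32}},                      // template arguments
        std::nullopt,                          // init_value
        false,                                 // verbose
        {});                                   // stream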
@@ -17,6 +17,66 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;

namespace {

struct PyCustomKernelFunction {
  PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag)
      : kernel_(std::move(kernel)), tag_(tag) {}

  std::vector<mx::array> operator()(
      const std::vector<ScalarOrArray>& inputs_,
      const std::vector<mx::Shape>& output_shapes,
      const std::vector<mx::Dtype>& output_dtypes,
      std::tuple<int, int, int> grid,
      std::tuple<int, int, int> threadgroup,
      const std::optional<std::vector<std::pair<std::string, nb::object>>>&
          template_args_ = std::nullopt,
      std::optional<float> init_value = std::nullopt,
      bool verbose = false,
      mx::StreamOrDevice s = {}) const {
    std::vector<mx::array> inputs;
    for (const auto& value : inputs_) {
      inputs.push_back(to_array(value, std::nullopt));
    }
    std::vector<std::pair<std::string, mx::fast::TemplateArg>> template_args;
    if (template_args_) {
      for (const auto& [name, value] : template_args_.value()) {
        // Handle bool, int and dtype template args
        if (nb::isinstance<bool>(value)) {
          bool bool_val = nb::cast<bool>(value);
          template_args.emplace_back(name, bool_val);
        } else if (nb::isinstance<int>(value)) {
          int int_val = nb::cast<int>(value);
          template_args.emplace_back(name, int_val);
        } else if (nb::isinstance<mx::Dtype>(value)) {
          mx::Dtype dtype = nb::cast<mx::Dtype>(value);
          template_args.emplace_back(name, dtype);
        } else {
          std::ostringstream msg;
          msg << tag_
              << " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.";
          throw std::invalid_argument(msg.str());
        }
      }
    }
    return kernel_(
        inputs,
        output_shapes,
        output_dtypes,
        grid,
        threadgroup,
        template_args,
        init_value,
        verbose,
        s);
  }

  mx::fast::CustomKernelFunction kernel_;
  const char* tag_;
};

} // namespace

void init_fast(nb::module_& parent_module) {
  auto m =
      parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
@@ -240,53 +300,7 @@ void init_fast(nb::module_& parent_module) {
            ensure_row_contiguous,
            atomic_outputs);
        return nb::cpp_function(
            [kernel = std::move(kernel)](
                const std::vector<ScalarOrArray>& inputs_,
                const std::vector<mx::Shape>& output_shapes,
                const std::vector<mx::Dtype>& output_dtypes,
                std::tuple<int, int, int> grid,
                std::tuple<int, int, int> threadgroup,
                const std::optional<
                    std::vector<std::pair<std::string, nb::object>>>&
                    template_args_ = std::nullopt,
                std::optional<float> init_value = std::nullopt,
                bool verbose = false,
                mx::StreamOrDevice s = {}) {
              std::vector<mx::array> inputs;
              for (const auto& value : inputs_) {
                inputs.push_back(to_array(value, std::nullopt));
              }
              std::vector<std::pair<std::string, mx::fast::TemplateArg>>
                  template_args;
              if (template_args_) {
                for (const auto& [name, value] : template_args_.value()) {
                  // Handle bool, int and dtype template args
                  if (nb::isinstance<bool>(value)) {
                    bool bool_val = nb::cast<bool>(value);
                    template_args.emplace_back(name, bool_val);
                  } else if (nb::isinstance<int>(value)) {
                    int int_val = nb::cast<int>(value);
                    template_args.emplace_back(name, int_val);
                  } else if (nb::isinstance<mx::Dtype>(value)) {
                    mx::Dtype dtype = nb::cast<mx::Dtype>(value);
                    template_args.emplace_back(name, dtype);
                  } else {
                    throw std::invalid_argument(
                        "[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
                  }
                }
              }
              return kernel(
                  inputs,
                  output_shapes,
                  output_dtypes,
                  grid,
                  threadgroup,
                  template_args,
                  init_value,
                  verbose,
                  s);
            },
            PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"),
            nb::kw_only(),
            "inputs"_a,
            "output_shapes"_a,
@@ -386,7 +400,123 @@ void init_fast(nb::module_& parent_module) {
      )pbdoc");

  m.def(
      "precompiled_custom_kernel",
      "cuda_kernel",
      [](const std::string& name,
         const std::vector<std::string>& input_names,
         const std::vector<std::string>& output_names,
         const std::string& source,
         const std::string& header,
         bool ensure_row_contiguous,
         int shared_mem) {
        auto kernel = mx::fast::cuda_kernel(
            name,
            input_names,
            output_names,
            source,
            header,
            ensure_row_contiguous,
            shared_mem);
        return nb::cpp_function(
            PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"),
            nb::kw_only(),
            "inputs"_a,
            "output_shapes"_a,
            "output_dtypes"_a,
            "grid"_a,
            "threadgroup"_a,
            "template"_a = nb::none(),
            "init_value"_a = nb::none(),
            "verbose"_a = false,
            "stream"_a = nb::none(),
            nb::sig(
                "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
            R"pbdoc(
              Run the kernel.

              Args:
                inputs (List[array]): The inputs passed to the CUDA kernel.
                output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
                output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
                grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
                  For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
                threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
                template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
                  These will be added as template arguments to the kernel definition. Default: ``None``.
                init_value (float, optional): Optional value to use to initialize all of the output arrays.
                  By default, output arrays are uninitialized. Default: ``None``.
                verbose (bool, optional): Whether to print the full generated source code of the kernel
                  when it is run. Default: ``False``.
                stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.

              Returns:
                List[array]: The list of output arrays.)pbdoc");
      },
      "name"_a,
      "input_names"_a,
      "output_names"_a,
      "source"_a,
      "header"_a = "",
      "ensure_row_contiguous"_a = true,
      "shared_memory"_a = 0,
      R"pbdoc(
        A jit-compiled custom CUDA kernel defined from a source string.

        This is the CUDA equivalent of :ref:`custom_metal_kernels`.

        Args:
          name (str): Name for the kernel.
          input_names (List[str]): The parameter names of the inputs in the
            function signature.
          output_names (List[str]): The parameter names of the outputs in the
            function signature.
          source (str): Source code. This is the body of a function in CUDA,
            the function signature will be automatically generated.
          header (str): Header source code to include before the main function.
            Useful for helper functions or includes that should live outside of
            the main function body.
          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
            before the kernel runs. Default: ``True``.
          shared_memory (int): The dynamic shared memory to request for the
            kernel. A value of 0 means no dynamic shared memory. Default: ``0``.

        Returns:
          Callable ``cuda_kernel``.

        Example:

          .. code-block:: python

            def exp_elementwise(a: mx.array):
                source = '''
                    auto elem = cooperative_groups::this_grid().thread_rank();
                    T tmp = inp[elem];
                    out[elem] = exp(tmp);
                '''

                kernel = mx.fast.cuda_kernel(
                    name="myexp",
                    input_names=["inp"],
                    output_names=["out"],
                    source=source
                )
                outputs = kernel(
                    inputs=[a],
                    template=[("T", mx.float32)],
                    grid=(a.size, 1, 1),
                    threadgroup=(256, 1, 1),
                    output_shapes=[a.shape],
                    output_dtypes=[a.dtype],
                    verbose=True,
                )
                return outputs[0]

            a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
            b = exp_elementwise(a)
            assert mx.allclose(b, mx.exp(a))
      )pbdoc");

  m.def(
      "precompiled_cuda_kernel",
      [](const std::string& name,
         const nb::bytes compiled_source,
         const std::vector<ScalarOrArray>& inputs_,
@@ -420,14 +550,14 @@ void init_fast(nb::module_& parent_module) {
            std::string vtype_name =
                nb::cast<std::string>(vtype.attr("__name__"));
            std::ostringstream msg;
            msg << "[precompiled_custom_kernel] Invalid scalar argument type. "
            msg << "[precompiled_cuda_kernel] Invalid scalar argument type. "
                << "Received " << vtype_name
                << " but must be one of bool, int or float";
            throw std::invalid_argument(msg.str());
          }
        }

        return mx::fast::precompiled_custom_kernel(
        return mx::fast::precompiled_cuda_kernel(
            name,
            std::string(
                static_cast<const char*>(compiled_source.data()),
@@ -457,5 +587,27 @@ void init_fast(nb::module_& parent_module) {
      "ensure_row_contiguous"_a = false,
      "stream"_a = nb::none(),
      R"pbdoc(
        Run a precompiled CUDA kernel defined from PTX or cubin.

        This op is still experimental and various parts of the API may change.

        Args:
          name (str): Name for the kernel
          compiled_source (bytes): The precompiled kernel in raw bytes.
          inputs (List[array]): The inputs passed to the CUDA kernel.
          output_shapes (List[Sequence[int]]): The list of shapes for each output.
          output_dtypes (List[Dtype]): The list of data types for each output.
          scalars (List[Union[bool, int, float]]): A list of scalar arguments to
            pass to the kernel.
          grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
            For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
          threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
          shared_memory (int): The dynamic shared memory to request for the
            kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
          init_value (float, optional): Optional value to use to initialize all of the output arrays.
            By default, output arrays are uninitialized. Default: ``None``.
          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
            before the kernel runs. Default: ``False``.
          stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
      )pbdoc");
}
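For completeness, the same entry point is reachable from C++ as mx::fast::precompiled_cuda_kernel; the parameter order below follows the declaration visible in the no-CUDA stub earlier in this diff. Everything in the sketch is illustrative: `image` is assumed to hold PTX or cubin bytes whose entry symbol is named "my_kernel", and `mx` is the alias used in the binding file.

    // Hedged sketch: launch a kernel from an offline-compiled image.
    auto outputs = mx::fast::precompiled_cuda_kernel(
        "my_kernel",                            // name, must match a symbol in the image
        image,                                  // compiled_source (PTX/cubin as std::string)
        {a},                                    // inputs
        {a.shape()},                            // output_shapes
        {a.dtype()},                            // output_dtypes
        {},                                     // scalars
        {static_cast<int>(a.size()), 1, 1},     // grid, counted in threads
        {256, 1, 1},                            // threadgroup
        0,                                      // shared_memory
        std::nullopt,                           // init_value
        false,                                  // ensure_row_contiguous
        {});                                    // stream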