Custom cuda kernel (#2517)
Commit e397177f6e (parent f4c8888cbe), committed by GitHub.
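At a glance (editorial summary, derived from the diff below): the change adds a fast::cuda_kernel primitive and a precompiled_cuda_kernel variant that accepts raw PTX/cubin, in a new mlx/backend/cuda/custom_kernel.cpp; widens KernelBuilderResult from a pair to a tuple carrying a "precompiled" flag; and extends JitModule to skip NVRTC for precompiled sources, optionally bypass the disk cache, and run a one-time configure_kernel callback (used here to raise the dynamic shared-memory limit).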
@@ -20,6 +20,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
@@ -267,7 +267,8 @@ void Compiled::eval_gpu(
       }
     }

-    return std::make_pair(std::move(builder.os), std::move(kernel_names));
+    return std::make_tuple(
+        false, std::move(builder.os), std::move(kernel_names));
   });

   // Collapse contiguous dims to route to a faster kernel if possible. Also
mlx/backend/cuda/custom_kernel.cpp (new file, 379 lines)
@@ -0,0 +1,379 @@
// Copyright © 2025 Apple Inc.

#include <iostream>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"

#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core::fast {

namespace {

constexpr const char* default_header = R"(
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

#define inf cuda::std::numeric_limits<float>::infinity()

)";

std::string template_arguments_hash(
    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
  if (template_args.empty()) {
    return "";
  }

  std::string hash;
  hash.reserve(512);

  for (const auto& [name, arg] : template_args) {
    if (std::holds_alternative<int>(arg)) {
      hash += fmt::format("_{}", std::get<int>(arg));
    } else if (std::holds_alternative<bool>(arg)) {
      hash += (std::get<bool>(arg)) ? "_t" : "_f";
    } else if (std::holds_alternative<Dtype>(arg)) {
      hash += "_";
      hash += get_type_string(std::get<Dtype>(arg));
    }
  }

  return hash;
}

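// Illustrative note (editorial, not in the original source): for
// template_args {{"N", 4}, {"relu", true}, {"T", float32}} the loop above
// yields the suffix "_4_t_float32", so the jitted kernel is cached under a
// name like "custom_kernel_myop_4_t_float32".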
std::string build_kernel(
    const std::string& func_name,
    const std::string& header,
    const std::string& source,
    const std::vector<std::string>& input_names,
    const std::vector<array>& inputs,
    const std::vector<std::string>& output_names,
    const std::vector<Dtype>& output_dtypes,
    const std::vector<std::pair<std::string, TemplateArg>>& template_args,
    const std::vector<CustomKernelShapeInfo>& shape_infos) {
  std::string kernel_source;
  kernel_source.reserve(header.size() + source.size() + 8192);
  kernel_source += default_header;
  kernel_source += header;
  kernel_source +=
      "namespace mlx::core::cu {\n\n"
      "namespace cg = cooperative_groups;\n\n";

  kernel_source += "__global__ void ";
  kernel_source += func_name;
  kernel_source += "(\n";

  // Add inputs
  for (int i = 0; i < inputs.size(); ++i) {
    const auto& name = input_names[i];
    const auto& arr = inputs[i];
    kernel_source += " const ";
    kernel_source += dtype_to_cuda_type(arr.dtype());
    kernel_source += "* ";
    kernel_source += name;
    kernel_source += ",\n";
    // Add input shape, strides and ndim if present in the source
    if (arr.ndim() > 0) {
      if (shape_infos[i].shape) {
        kernel_source += " const __grid_constant__ Shape ";
        kernel_source += name;
        kernel_source += "_shape,\n";
      }
      if (shape_infos[i].strides) {
        kernel_source += " const __grid_constant__ Strides ";
        kernel_source += name;
        kernel_source += "_strides,\n";
      }
      if (shape_infos[i].ndim) {
        kernel_source += " const __grid_constant__ int ";
        kernel_source += name;
        kernel_source += "_ndim,\n";
      }
    }
  }

  // Add outputs
  for (int i = 0; i < output_names.size(); ++i) {
    const auto& name = output_names[i];
    const auto& dtype = output_dtypes[i];
    kernel_source += " ";
    kernel_source += dtype_to_cuda_type(dtype);
    kernel_source += "* ";
    kernel_source += name;
    if (i < output_names.size() - 1) {
      kernel_source += ",\n";
    } else {
      kernel_source += ") {\n";
    }
  }

  // Set compile time constants
  if (!template_args.empty()) {
    for (const auto& [name, arg] : template_args) {
      if (std::holds_alternative<int>(arg)) {
        kernel_source +=
            fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
      } else if (std::holds_alternative<bool>(arg)) {
        kernel_source += fmt::format(
            " constexpr bool {} = {};\n", name, std::get<bool>(arg));
      } else {
        kernel_source += fmt::format(
            " using {} = {};\n",
            name,
            dtype_to_cuda_type(std::get<Dtype>(arg)));
      }
    }
    kernel_source += "\n";
  }

  kernel_source += source;
  kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";

  return kernel_source;
}

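// Illustrative note (editorial, not in the original source): for a single
// float32 input "inp" whose source references inp_shape, one float32 output
// "out", and template_args {{"N", 4}}, build_kernel emits roughly:
//
//   __global__ void custom_kernel_myop_4(
//    const float* inp,
//    const __grid_constant__ Shape inp_shape,
//    float* out) {
//    constexpr int N = 4;
//    <user source>
//   }
//
// wrapped in namespace mlx::core::cu, with default_header and the user
// header prepended.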
} // namespace

CustomKernelFunction cuda_kernel(
    const std::string& name,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::string& source,
    const std::string& header,
    bool ensure_row_contiguous,
    int shared_memory) {
  if (output_names.empty()) {
    throw std::invalid_argument(
        "[custom_kernel] Must specify at least one output.");
  }

  std::vector<CustomKernelShapeInfo> shape_infos;
  for (auto& n : input_names) {
    CustomKernelShapeInfo shape_info;
    shape_info.shape = source.find(n + "_shape") != std::string::npos;
    shape_info.strides = source.find(n + "_strides") != std::string::npos;
    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
    shape_infos.push_back(shape_info);
  }

  return [=, shape_infos = std::move(shape_infos)](
             const std::vector<array>& inputs,
             const std::vector<Shape>& output_shapes,
             const std::vector<Dtype>& output_dtypes,
             std::tuple<int, int, int> grid,
             std::tuple<int, int, int> threadgroup,
             const std::vector<std::pair<std::string, TemplateArg>>&
                 template_args = {},
             std::optional<float> init_value = std::nullopt,
             bool verbose = false,
             StreamOrDevice s_ = {}) {
    if (inputs.size() != input_names.size()) {
      std::ostringstream msg;
      msg << "[custom_kernel] Expected `inputs` to have size "
          << input_names.size() << " but got size " << inputs.size() << "."
          << std::endl;
      throw std::invalid_argument(msg.str());
    }
    if (output_shapes.size() != output_names.size()) {
      std::ostringstream msg;
      msg << "[custom_kernel] Expected `output_shapes` to have size "
          << output_names.size() << " but got size " << output_shapes.size()
          << "." << std::endl;
      throw std::invalid_argument(msg.str());
    }
    if (output_dtypes.size() != output_names.size()) {
      std::ostringstream msg;
      msg << "[custom_kernel] Expected `output_dtypes` to have size "
          << output_names.size() << " but got size " << output_dtypes.size()
          << "." << std::endl;
      throw std::invalid_argument(msg.str());
    }

    auto s = to_stream(s_);
    if (s.device != Device::gpu) {
      throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
    }

    std::string kernel_name =
        "custom_kernel_" + name + template_arguments_hash(template_args);
    std::string kernel_source = build_kernel(
        kernel_name,
        header,
        source,
        input_names,
        inputs,
        output_names,
        output_dtypes,
        template_args,
        shape_infos);

    if (verbose) {
      std::cout << "Generated source code for `" << kernel_name
                << "`:" << std::endl
                << "```" << std::endl
                << kernel_source << std::endl
                << "```" << std::endl;
    }

    return array::make_arrays(
        std::move(output_shapes),
        std::move(output_dtypes),
        std::make_shared<CustomKernel>(
            s,
            std::move(kernel_name),
            std::move(kernel_source),
            grid,
            threadgroup,
            shape_infos,
            ensure_row_contiguous,
            init_value,
            std::vector<ScalarArg>{},
            false,
            shared_memory),
        std::move(inputs));
  };
}

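// Illustrative usage (editorial sketch, not part of the original commit;
// names and numbers are hypothetical):
//
//   auto kernel = cuda_kernel(
//       "myop",
//       /* input_names */ {"inp"},
//       /* output_names */ {"out"},
//       /* source */ R"(
//         uint elem = blockDim.x * blockIdx.x + threadIdx.x;
//         out[elem] = T(N) * inp[elem];
//       )",
//       /* header */ "",
//       /* ensure_row_contiguous */ true,
//       /* shared_memory */ 0);
//   auto outs = kernel(
//       {x}, {x.shape()}, {x.dtype()},
//       /* grid, in threads */ {1024, 1, 1},
//       /* threadgroup */ {256, 1, 1},
//       {{"N", 2}, {"T", float32}});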
std::vector<array> precompiled_cuda_kernel(
    const std::string& name,
    const std::string& compiled_source,
    const std::vector<array>& inputs,
    const std::vector<Shape>& output_shapes,
    const std::vector<Dtype>& output_dtypes,
    const std::vector<ScalarArg>& scalars,
    std::tuple<int, int, int> grid,
    std::tuple<int, int, int> threadgroup,
    int shared_memory,
    std::optional<float> init_value,
    bool ensure_row_contiguous,
    StreamOrDevice s) {
  std::vector<CustomKernelShapeInfo> shape_infos(
      inputs.size(), CustomKernelShapeInfo{false, false, false});
  return array::make_arrays(
      output_shapes,
      output_dtypes,
      std::make_shared<CustomKernel>(
          to_stream(s),
          name,
          compiled_source,
          grid,
          threadgroup,
          shape_infos,
          ensure_row_contiguous,
          init_value,
          scalars,
          true,
          shared_memory),
      inputs);
}

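// Note (editorial): shape_infos entries are all false here because a
// precompiled module is raw PTX/cubin, so the generated <name>_shape,
// <name>_strides and <name>_ndim convenience parameters that cuda_kernel
// above detects in JIT sources do not exist.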
void CustomKernel::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("CustomKernel::eval_gpu");
  auto& s = stream();

  std::vector<array> copies;

  // Allocate and initialize the output arrays
  for (auto& out : outputs) {
    if (init_value_) {
      copies.emplace_back(init_value_.value(), out.dtype());
      fill_gpu(copies.back(), out, s);
    } else {
      out.set_data(allocator::malloc(out.nbytes()));
    }
  }

  // Create the input arrays and copy if needed
  auto check_input = [&copies, &s, this](const array& x) -> const array {
    bool no_copy = x.flags().row_contiguous;
    if (!ensure_row_contiguous_ || no_copy) {
      return x;
    } else {
      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
      copy_gpu(x, copies.back(), CopyType::General, s);
      return copies.back();
    }
  };
  std::vector<array> checked_inputs;
  for (const array& in : inputs) {
    checked_inputs.push_back(check_input(in));
  }

  // Compile the custom kernel
  std::string kernel_name =
      (is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
  cu::JitModule& mod = cu::get_jit_module(
      s.device,
      name_,
      [&]() {
        return std::make_tuple(
            is_precompiled_, source_, std::vector{kernel_name});
      },
      false);

  // Make the arguments
  cu::KernelArgs args;
  for (int i = 0; i < checked_inputs.size(); i++) {
    const array& in = checked_inputs[i];
    auto& shape_info = shape_infos_[i];
    args.append(in);
    if (shape_info.shape) {
      args.append_ndim(in.shape());
    }
    if (shape_info.strides) {
      args.append_ndim(in.strides());
    }
    if (shape_info.ndim) {
      args.append<int32_t>(in.ndim());
    }
  }
  for (auto& out : outputs) {
    args.append(out);
  }
  for (auto& s : scalar_arguments_) {
    if (std::holds_alternative<bool>(s)) {
      args.append(std::get<bool>(s));
    } else if (std::holds_alternative<int>(s)) {
      args.append(std::get<int>(s));
    } else if (std::holds_alternative<float>(s)) {
      args.append(std::get<float>(s));
    }
  }

  // Make the grid
  const auto [tx, ty, tz] = threadgroup_;
  const auto [gx, gy, gz] = grid_;
  dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
  dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
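  // Illustrative note (editorial, not in the original source): grid_ is
  // specified in threads, Metal-style, and converted to CUDA blocks here.
  // E.g. grid_ = (4096, 2, 1) with threadgroup_ = (256, 1, 1) gives
  // block = (256, 1, 1) and grid = (16, 2, 1).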

  // Call the kernel
  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : checked_inputs) {
    encoder.set_input_array(in);
  }
  for (const auto& out : outputs) {
    encoder.set_output_array(out);
  }
  for (const auto& t : copies) {
    encoder.add_temporary(t);
  }
  auto kernel =
      mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
        if (smem > 0 && smem > 48000) {
          cuFuncSetAttribute(
              kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
        }
      });
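  // Note (editorial): CUDA caps dynamic shared memory at 48 KB per block by
  // default; larger requests must opt in via
  // CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, which the configure
  // callback above does once, on the kernel's first use.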
  encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}

} // namespace mlx::core::fast

@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
           large ? "int64_t" : "int32_t"));
       }
     }
-    return std::make_pair(jit_source_gather, std::move(kernel_names));
+    return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
   });

   cu::KernelArgs args;
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
           large ? "int64_t" : "int32_t"));
       }
     }
-    return std::make_pair(jit_source_scatter, std::move(kernel_names));
+    return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
   });

   cu::KernelArgs args;
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       }
      }
    }
-    return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
+    return std::make_tuple(
+        false, jit_source_gather_axis, std::move(kernel_names));
   });

   size_t idx_size_pre = 1;
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       }
      }
    }
-    return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
+    return std::make_tuple(
+        false, jit_source_scatter_axis, std::move(kernel_names));
   });

   size_t idx_size_pre = 1;

@@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
 bool read_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
-    std::vector<char>* ptx,
-    std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+    std::string& ptx,
+    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
   if (cache_dir.empty()) {
     return false;
   }
@@ -117,15 +117,15 @@ bool read_cached_ptx(
   if (!ptx_file.good()) {
     return false;
   }
-  ptx->resize(ptx_size);
-  ptx_file.read(ptx->data(), ptx_size);
+  ptx.resize(ptx_size);
+  ptx_file.read(ptx.data(), ptx_size);

   std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
   std::string line;
   while (std::getline(txt_file, line)) {
     auto tab = line.find('\t');
     if (tab != std::string::npos) {
-      ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
+      ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
     }
   }
   return true;
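As the code above shows, the disk cache keeps two files per module: a binary
blob with the PTX or cubin (read into `ptx`) and <module_name>.txt holding one
tab-separated "kernel name<TAB>mangled name" pair per line.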
@@ -135,7 +135,7 @@ bool read_cached_ptx(
 void write_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
-    const std::vector<char>& ptx,
+    const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     const std::string& source_code) {
   if (cache_dir.empty()) {
@@ -217,85 +217,85 @@ constexpr const char* g_headers[] = {
     jit_source_utils,
 };

 } // namespace

-JitModule::JitModule(
+void compile(
     Device& device,
     const std::string& module_name,
-    const KernelBuilder& builder) {
-  // Check cache.
-  std::vector<char> ptx;
-  std::vector<std::pair<std::string, std::string>> ptx_kernels;
-  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
-    // Create program.
-    auto [source_code, kernel_names] = builder();
-    nvrtcProgram prog;
-    CHECK_NVRTC_ERROR(nvrtcCreateProgram(
-        &prog,
-        source_code.c_str(),
-        (module_name + ".cu").c_str(),
-        std::size(g_headers),
-        g_headers,
-        g_include_names));
-    std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
-        &prog,
-        [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
-    for (const auto& name : kernel_names) {
-      CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
-    }
-
-    // Compile program.
-    std::vector<const char*> args;
-    bool use_sass = compiler_supports_device_sass(device);
-    std::string compute = fmt::format(
-        "--gpu-architecture={}_{}{}",
-        use_sass ? "sm" : "compute",
-        device.compute_capability_major(),
-        device.compute_capability_minor());
-    args.push_back(compute.c_str());
-    std::string cccl_include = cccl_dir();
-    if (!cccl_include.empty()) {
-      cccl_include = fmt::format("--include-path={}", cccl_include);
-      args.push_back(cccl_include.c_str());
-    }
-    std::string cuda_include =
-        fmt::format("--include-path={}/include", cuda_home());
-    args.push_back(cuda_include.c_str());
-    nvrtcResult compile_result =
-        nvrtcCompileProgram(prog, args.size(), args.data());
-    if (compile_result != NVRTC_SUCCESS) {
-      size_t log_size;
-      CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
-      std::vector<char> log(log_size + 1, 0);
-      CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
-      throw std::runtime_error(
-          fmt::format("Failed to compile kernel: {}.", log.data()));
-    }
-
-    // Get mangled names of kernel names.
-    for (const auto& name : kernel_names) {
-      const char* mangled;
-      CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
-      ptx_kernels.emplace_back(name, mangled);
-    }
-
-    // Get ptx data.
-    size_t ptx_size;
-    if (use_sass) {
-      CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
-    } else {
-      CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
-    }
-    ptx.resize(ptx_size, 0);
-    if (use_sass) {
-      CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
-    } else {
-      CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
-    }
-    write_cached_ptx(
-        ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
+    const std::string& source,
+    const std::vector<std::string>& kernel_names,
+    std::string& ptx,
+    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+  // Create the program
+  nvrtcProgram prog;
+  CHECK_NVRTC_ERROR(nvrtcCreateProgram(
+      &prog,
+      source.c_str(),
+      (module_name + ".cu").c_str(),
+      std::size(g_headers),
+      g_headers,
+      g_include_names));
+  std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
+      &prog,
+      [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
+  for (const auto& name : kernel_names) {
+    CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
+  }
+
+  // Compile program.
+  std::vector<const char*> args;
+  bool use_sass = compiler_supports_device_sass(device);
+  std::string compute = fmt::format(
+      "--gpu-architecture={}_{}{}",
+      use_sass ? "sm" : "compute",
+      device.compute_capability_major(),
+      device.compute_capability_minor());
+  args.push_back(compute.c_str());
+  std::string cccl_include = cccl_dir();
+  if (!cccl_include.empty()) {
+    cccl_include = fmt::format("--include-path={}", cccl_include);
+    args.push_back(cccl_include.c_str());
+  }
+  std::string cuda_include =
+      fmt::format("--include-path={}/include", cuda_home());
+  args.push_back(cuda_include.c_str());
+  nvrtcResult compile_result =
+      nvrtcCompileProgram(prog, args.size(), args.data());
+  if (compile_result != NVRTC_SUCCESS) {
+    size_t log_size;
+    CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
+    std::vector<char> log(log_size + 1, 0);
+    CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
+    throw std::runtime_error(
+        fmt::format("Failed to compile kernel: {}.", log.data()));
+  }
+
+  // Get mangled names of kernel names.
+  for (const auto& name : kernel_names) {
+    const char* mangled;
+    CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
+    ptx_kernels.emplace_back(name, mangled);
+  }
+
+  // Get ptx data.
+  size_t ptx_size;
+  if (use_sass) {
+    CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
+  } else {
+    CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
+  }
+  ptx.resize(ptx_size);
+  if (use_sass) {
+    CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
+  } else {
+    CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
+  }
+}
+
+void load_module(
+    const std::string& module_name,
+    const std::string& ptx,
+    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
+    CUmodule& module_,
+    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {
@@ -312,21 +312,69 @@ JitModule::JitModule(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels_[name] = kernel;
+    kernels[name] = std::make_pair(kernel, false);
   }
 }

+} // namespace
+
+JitModule::JitModule(
+    Device& device,
+    const std::string& module_name,
+    const KernelBuilder& builder,
+    bool use_disk_cache) {
+  // Will hold the actual device executable source code and kernel names
+  std::string ptx;
+  std::vector<std::pair<std::string, std::string>> ptx_kernels;
+
+  // Try to load them from the file cache
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
+    auto [precompiled, source_code, kernel_names] = builder();
+
+    // Get the PTX or cubin
+    if (precompiled) {
+      ptx = std::move(source_code);
+      for (auto& name : kernel_names) {
+        ptx_kernels.emplace_back(name, name);
+      }
+    } else {
+      compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
+    }
+
+    // If requested save them in the file cache for the next launch
+    if (use_disk_cache) {
+      write_cached_ptx(
+          ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
+    }
+  }
+
+  // Load the module
+  load_module(module_name, ptx, ptx_kernels, module_, kernels_);
+}
+
 JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }

-CUfunction JitModule::get_kernel(const std::string& kernel_name) {
+CUfunction JitModule::get_kernel(
+    const std::string& kernel_name,
+    std::function<void(CUfunction)> configure_kernel) {
   auto it = kernels_.find(kernel_name);
   if (it == kernels_.end()) {
     throw std::runtime_error(
         fmt::format("There is no kernel named {}.", kernel_name));
   }
-  return it->second;
+
+  // If it is the first time we run this kernel then configure it. Do it only
+  // once!
+  if (!it->second.second) {
+    if (configure_kernel) {
+      configure_kernel(it->second.first);
+    }
+    it->second.second = true;
+  }
+
+  return it->second.first;
 }

 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
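The configure callback runs at most once per kernel, guarded by the bool in
the kernels_ map. A minimal sketch of the intended call pattern (hypothetical
values, mirroring the usage in custom_kernel.cpp above):

    auto kernel = mod.get_kernel("my_kernel", [](CUfunction k) {
      // Raise the dynamic shared-memory limit once, on first use.
      cuFuncSetAttribute(
          k, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, 65536);
    });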
@@ -337,11 +385,12 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
-    const KernelBuilder& builder) {
+    const KernelBuilder& builder,
+    bool cache) {
   auto& map = get_jit_module_cache();
   auto it = map.find(name);
   if (it == map.end()) {
-    it = map.try_emplace(name, cu::device(device), name, builder).first;
+    it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
   }
   return it->second;
 }

@@ -19,7 +19,8 @@ namespace mlx::core::cu {

 class Device;

-using KernelBuilderResult = std::pair<
+using KernelBuilderResult = std::tuple<
+    /* precompiled */ bool,
     /* source code */ std::string,
     /* kernel names */ std::vector<std::string>>;
 using KernelBuilder = std::function<KernelBuilderResult()>;
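With the new tuple, a builder signals whether its payload is CUDA C++ to be
NVRTC-compiled or an already-compiled PTX/cubin blob. A minimal sketch
(hypothetical names; `source` holds CUDA C++ source):

    KernelBuilder builder = []() -> KernelBuilderResult {
      // false: "source" still needs NVRTC compilation.
      return std::make_tuple(
          false, source, std::vector<std::string>{"mlx::core::cu::my_kernel"});
    };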
@@ -63,14 +64,16 @@ struct KernelArgs {
  private:
   std::vector<void*> args_;

-  // The cuLaunchKernel API requires passing pointers to arguments so store
-  // temporary values untill kernel is launched.
+  // The cuGraphAddKernelNode API requires passing pointers to arguments so
+  // store temporary values until the node is created.
   using Arg = std::variant<
       std::monostate,
       CUdeviceptr,
+      bool,
       int32_t,
       uint32_t,
       int64_t,
+      float,
       SmallVector<const void*>,
       SmallVector<int32_t>,
       SmallVector<int64_t>>;
@@ -82,16 +85,19 @@ class JitModule {
   JitModule(
       Device& device,
       const std::string& module_name,
-      const KernelBuilder& builder);
+      const KernelBuilder& builder,
+      bool cache);
   ~JitModule();

   JitModule(const JitModule&) = delete;
   JitModule& operator=(const JitModule&) = delete;
-  CUfunction get_kernel(const std::string& kernel_name);
+  CUfunction get_kernel(
+      const std::string& kernel_name,
+      std::function<void(CUfunction)> configure_kernel = nullptr);

  private:
   CUmodule module_{nullptr};
-  std::unordered_map<std::string, CUfunction> kernels_;
+  std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
 };

 std::unordered_map<std::string, JitModule>& get_jit_module_cache();
@@ -99,6 +105,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
-    const KernelBuilder& builder);
+    const KernelBuilder& builder,
+    bool use_disk_cache = true);

 } // namespace mlx::core::cu

@@ -1,11 +1,47 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/cuda/cuda.h"
+#include "mlx/fast.h"

-namespace mlx::core::cu {
+namespace mlx::core {
+
+namespace cu {

 bool is_available() {
   return false;
 }

-} // namespace mlx::core::cu
+} // namespace cu
+
+namespace fast {
+
+CustomKernelFunction cuda_kernel(
+    const std::string&,
+    const std::vector<std::string>&,
+    const std::vector<std::string>&,
+    const std::string&,
+    const std::string&,
+    bool,
+    int) {
+  throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
+}
+
+std::vector<array> precompiled_cuda_kernel(
+    const std::string&,
+    const std::string&,
+    const std::vector<array>&,
+    const std::vector<Shape>&,
+    const std::vector<Dtype>&,
+    const std::vector<ScalarArg>&,
+    std::tuple<int, int, int>,
+    std::tuple<int, int, int>,
+    int shared_memory,
+    std::optional<float> init_value,
+    bool ensure_row_contiguous,
+    StreamOrDevice) {
+  throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
+}
+
+} // namespace fast
+
+} // namespace mlx::core
@@ -41,10 +41,6 @@ NO_GPU(Cholesky)
 NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)

-namespace fast {
-NO_GPU_MULTI(CustomKernel)
-} // namespace fast
-
 namespace distributed {
 NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)