[CUDA] Add work per thread to compile (#2368)

commit 6b1b8ea91b, parent b2273733ea
mirror of https://github.com/ml-explore/mlx.git
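The diff below threads a `work_per_thread` template parameter through the JIT-compiled fused elementwise kernels so that each thread processes several consecutive elements instead of one. Before reading the string-building code, here is a minimal sketch of what the emitted contiguous kernel roughly expands to after this change (the kernel name, the float-only body, and the unary op are illustrative; the real source is assembled as strings by `FusedKernelBuilder` below):

```cpp
// Hypothetical expansion of the emitted contiguous kernel for a unary op.
#include <cooperative_groups.h>
#include <cstdint>
namespace cg = cooperative_groups;

template <typename IdxT = uint32_t, int work_per_thread = 1>
__global__ void example_contiguous(const float* in, float* out, IdxT size) {
  // Each thread now owns a run of work_per_thread consecutive elements.
  IdxT index = cg::this_grid().thread_rank() * work_per_thread;
  if (index >= size) {
    return;
  }
  // Work loop; `index < size` guards the ragged tail of the last thread.
  for (int i = 0; i < work_per_thread && index < size; i++) {
    float tmp_in = in[index];  // read inputs
    out[index] = tmp_in;       // fused ops and the output write go here
    index++;
  }
}
```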
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
     // Build function signature.
     if (contiguous) {
-      os += "template <typename IdxT = uint32_t>\n";
+      os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
     } else {
-      os += "template <int NDIM, typename IdxT = uint32_t>\n";
+      os +=
+          "template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
     }
     os += fmt::format("__global__ void {}(\n", kernel_name + name);
     for (size_t i = 0; i < params.size(); ++i) {
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
     }
     os += ") {\n";
 
-    // Index.
+    // Index. For non contiguous kernels we create a separate index
+    // variable per variable otherwise everyone uses `index`.
     os +=
-        "  IdxT index = cg::this_grid().thread_rank();\n"
+        "  IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
        "  if (index >= size) {\n"
        "    return;\n"
        "  }\n";
+    if (!contiguous) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        const std::string& xname = namer.get_name(x);
+        if (is_scalar(x) || is_constant(i)) {
+          continue;
+        }
+        os += "  IdxT " + xname + "_idx = 0;\n";
+      }
+      os += "  {\n";
+      os += "    IdxT loc = index;\n";
+      os +=
+          "    #pragma unroll\n"
+          "    for (int i = NDIM - 1; i >= 0; i--) {\n";
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        const std::string& xname = namer.get_name(x);
+        if (is_scalar(x) || is_constant(i)) {
+          continue;
+        }
+        os += "      " + xname + "_idx += (loc % shape[i]) * IdxT(" + xname +
+            "_strides[i]);\n";
+      }
+      os +=
+          "      loc /= shape[i];\n"
+          "    }\n"
+          "  }\n";
+    }
+
+    // Work loop
+    os +=
+        "\n"
+        "  for (int i = 0; i < work_per_thread && index < size; i++) {\n";
 
     // Read inputs.
     for (size_t i = 0; i < inputs.size(); ++i) {
@@ -89,12 +124,9 @@ struct FusedKernelBuilder {
       } else if (contiguous) {
         value = fmt::format("{}[index]", xname);
       } else {
-        std::string index = fmt::format(
-            "elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
-            xname);
-        value = fmt::format("{}[{}]", xname, index);
+        value = fmt::format("{}[{}_idx]", xname, xname);
       }
       os += fmt::format("  {} tmp_{} = {};\n", type, xname, value);
     }
 
     // Write tape.
@@ -113,14 +145,30 @@ struct FusedKernelBuilder {
       }
       value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
     }
     os += fmt::format("  {} tmp_{} = {};\n", type, xname, value);
   }
 
   // Write output.
   for (const auto& x : outputs) {
     os += fmt::format("  {0}[index] = tmp_{0};\n", namer.get_name(x));
   }
 
+    // End of work loop
+    os +=
+        "\n"
+        "    index++;\n";
+    if (!contiguous) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        const std::string& xname = namer.get_name(x);
+        if (is_scalar(x) || is_constant(i)) {
+          continue;
+        }
+        os += "    " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
+      }
+    }
+    os += "  }\n";
+
     os += "}\n";
   }
 };
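Taken together, the builder changes swap the old per-element `elem_to_loc_nd` call for one index decomposition per thread plus an incremental advance along the innermost stride. A minimal sketch, with illustrative names and a single non-scalar input `a`, of what the emitted strided kernel roughly expands to (the real parameter types and fused body are generated by the builder above):

```cpp
// Hypothetical expansion of the emitted strided kernel; NDIM, the raw array
// parameters, and the float body are illustrative only.
#include <cooperative_groups.h>
#include <cstdint>
namespace cg = cooperative_groups;

template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>
__global__ void example_strided(
    const float* a,
    float* out,
    IdxT size,
    const int shape[NDIM],
    const int64_t a_strides[NDIM]) {
  IdxT index = cg::this_grid().thread_rank() * work_per_thread;
  if (index >= size) {
    return;
  }
  // Decompose the linear index into a's offset once per thread.
  IdxT a_idx = 0;
  {
    IdxT loc = index;
#pragma unroll
    for (int i = NDIM - 1; i >= 0; i--) {
      a_idx += (loc % shape[i]) * IdxT(a_strides[i]);
      loc /= shape[i];
    }
  }
  for (int i = 0; i < work_per_thread && index < size; i++) {
    float tmp_a = a[a_idx];
    out[index] = tmp_a;
    index++;
    // Advance along the innermost dimension only.
    a_idx += a_strides[NDIM - 1];
  }
}
```

The `a_idx += a_strides[NDIM - 1]` advance assumes all of a thread's elements sit in the same innermost row; the launch code below guarantees that by dropping to `work_per_thread = 1` whenever the innermost extent is not a multiple of it.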
@ -156,15 +204,28 @@ void Compiled::eval_gpu(
|
|||||||
builder.build("_strided", false);
|
builder.build("_strided", false);
|
||||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||||
// Build kernel names.
|
// Build kernel names.
|
||||||
std::vector<std::string> kernel_names = {
|
std::vector<std::string> kernel_names;
|
||||||
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||||
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
|
||||||
};
|
|
||||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
|
||||||
kernel_names.push_back(fmt::format(
|
kernel_names.push_back(fmt::format(
|
||||||
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||||
kernel_names.push_back(
|
lib_name(),
|
||||||
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
work_per_thread));
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
work_per_thread));
|
||||||
|
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
i,
|
||||||
|
work_per_thread));
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
i,
|
||||||
|
work_per_thread));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||||
});
|
});
|
||||||
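Templates are only instantiated for the names requested, so the loop above registers both `work_per_thread` variants (1 and 4) for every index type and NDIM up front; the launch-time choice below then selects among already-compiled instantiations instead of triggering a new NVRTC pass. A standalone sketch of the resulting cross product (the hypothetical `fused_example` stands in for `lib_name()`, and `MAX_NDIM = 8` is an assumption):

```cpp
// Standalone sketch of the name cross product built above:
// 2 work_per_thread values x 2 index types x (1 contiguous + MAX_NDIM strided).
#include <fmt/core.h>
#include <array>
#include <string>
#include <vector>

int main() {
  constexpr int MAX_NDIM = 8;  // assumption: MLX's actual limit may differ
  std::vector<std::string> kernel_names;
  for (auto work_per_thread : std::array<int, 2>{1, 4}) {
    for (const char* idx : {"uint32_t", "int64_t"}) {
      kernel_names.push_back(fmt::format(
          "mlx::core::cu::fused_example_contiguous<{}, {}>",
          idx,
          work_per_thread));
      for (int i = 1; i <= MAX_NDIM; ++i) {
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::fused_example_strided<{}, {}, {}>",
            i,
            idx,
            work_per_thread));
      }
    }
  }
  fmt::print("{} kernel instantiations\n", kernel_names.size());  // prints 36
}
```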
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
     args.append<uint32_t>(outputs[0].data_size());
   }
 
+  // Choose work per thread
+  int work_per_thread = 4;
+  if (!contiguous && shape.back() % work_per_thread != 0) {
+    work_per_thread = 1;
+  }
+
   // Launch kernel.
   const char* index_type = large ? "int64_t" : "uint32_t";
   std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
   if (contiguous) {
-    kernel_name += fmt::format("_contiguous<{}>", index_type);
+    kernel_name +=
+        fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
   } else {
-    kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
+    kernel_name += fmt::format(
+        "_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
   }
   auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
@@ -224,7 +293,8 @@ void Compiled::eval_gpu(
   }
 
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
+  auto [num_blocks, block_dims] =
+      get_launch_args(kernel, outputs[0], large, work_per_thread);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
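On the launch side, contiguous kernels always use `work_per_thread = 4` and rely on the `index < size` tail check, while strided kernels fall back to 1 unless the innermost extent divides evenly, as the stride-advance trick requires. The extra `work_per_thread` argument to `get_launch_args` then has to shrink the grid accordingly; a minimal sketch of the expected arithmetic (hypothetical helper, not MLX's actual signature):

```cpp
// Hypothetical ceil-division grid sizing once each thread covers
// work_per_thread elements; MLX's get_launch_args may differ in detail.
#include <cstdint>
#include <utility>

std::pair<uint32_t, uint32_t> launch_dims(
    int64_t size, uint32_t block_dim, int work_per_thread) {
  // Each thread handles work_per_thread elements, so divide the grid first.
  int64_t threads = (size + work_per_thread - 1) / work_per_thread;
  auto num_blocks =
      static_cast<uint32_t>((threads + block_dim - 1) / block_dim);
  return {num_blocks, block_dim};
}
```

The remaining hunks touch the JIT module (`write_cached_ptx` / `JitModule`) and bundle in a small debugging aid.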
@@ -121,7 +121,8 @@ void write_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
     const std::vector<char>& ptx,
-    const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
+    const std::string& source_code) {
   if (cache_dir.empty()) {
     return;
   }
@@ -134,6 +135,9 @@ void write_cached_ptx(
   for (const auto& [name, mangled] : ptx_kernels) {
     txt_file << name << "\t" << mangled << std::endl;
   }
+
+  std::ofstream source_file(cache_dir / (module_name + ".cu"));
+  source_file << source_code;
 }
 
 // Return if |device|'s version is not newer than |major|.|minor| version.
@@ -272,7 +276,8 @@ JitModule::JitModule(
   } else {
     CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
   }
-  write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
+  write_cached_ptx(
+      ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
 }
 
 // Load module.
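Net effect of the JIT-module change: the cache directory now keeps the generated CUDA source next to the PTX and the kernel-name table, presumably so the JIT output can be inspected or recompiled by hand. A standalone sketch of just the new write, assuming the same layout:

```cpp
// Minimal sketch of the added source-file write; the real write_cached_ptx
// also emits the PTX blob and the kernel-name table alongside it.
#include <filesystem>
#include <fstream>
#include <string>

void write_cached_source(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    const std::string& source_code) {
  if (cache_dir.empty()) {
    return;  // caching disabled
  }
  // e.g. <cache_dir>/<module_name>.cu next to the cached PTX
  std::ofstream source_file(cache_dir / (module_name + ".cu"));
  source_file << source_code;
}
```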