mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
[CUDA] Vectorize generated kernels (#2444)
This commit is contained in:
parent
b26d88591c
commit
86258f292f
@ -104,10 +104,41 @@ struct FusedKernelBuilder {
|
||||
" }\n";
|
||||
}
|
||||
|
||||
// Vectorized read loop
|
||||
if (contiguous) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
os += fmt::format(
|
||||
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
|
||||
xname,
|
||||
type);
|
||||
}
|
||||
}
|
||||
|
||||
// Create some space for the outputs
|
||||
for (const auto& x : outputs) {
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
os += fmt::format(
|
||||
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
|
||||
}
|
||||
|
||||
// Work loop
|
||||
os +=
|
||||
"\n"
|
||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||
if (!contiguous) {
|
||||
os +=
|
||||
"\n"
|
||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||
} else {
|
||||
os +=
|
||||
"\n"
|
||||
" #pragma unroll\n"
|
||||
" for (int i = 0; i < work_per_thread; i++) {\n";
|
||||
}
|
||||
|
||||
// Read inputs.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
@ -122,7 +153,7 @@ struct FusedKernelBuilder {
|
||||
} else if (is_scalar(x)) {
|
||||
value = fmt::format("{}[0]", xname);
|
||||
} else if (contiguous) {
|
||||
value = fmt::format("{}[index]", xname);
|
||||
value = fmt::format("vec_{}[i]", xname);
|
||||
} else {
|
||||
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||
}
|
||||
@ -150,25 +181,30 @@ struct FusedKernelBuilder {
|
||||
|
||||
// Write output.
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
|
||||
// End of work loop
|
||||
os +=
|
||||
"\n"
|
||||
" index++;\n";
|
||||
if (!contiguous) {
|
||||
os += "\n";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
|
||||
os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
|
||||
}
|
||||
}
|
||||
os += " }\n";
|
||||
|
||||
// Store the output to global memory
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(
|
||||
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
|
||||
namer.get_name(x));
|
||||
}
|
||||
|
||||
os += "}\n";
|
||||
}
|
||||
};
|
||||
@ -192,6 +228,15 @@ void Compiled::eval_gpu(
|
||||
nvtx3::scoped_range r("Compiled::eval_gpu");
|
||||
auto& s = stream();
|
||||
|
||||
// Determine the work per thread for the vectorized reads/writes. We take it
|
||||
// as 16 divided by the largest itemsize among the outputs. Another heuristic
|
||||
// could be to divide by the largest itemsize among all arrays instead.
|
||||
int max_size = 1;
|
||||
for (const auto& x : outputs) {
|
||||
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
|
||||
}
|
||||
int work_per_thread = 16 / max_size;
|
||||
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
||||
// Build source code.
|
||||
cu::FusedKernelBuilder builder{
|
||||
@ -205,28 +250,23 @@ void Compiled::eval_gpu(
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names;
|
||||
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
|
||||
lib_name(),
|
||||
i,
|
||||
work_per_thread));
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
|
||||
lib_name(),
|
||||
i,
|
||||
work_per_thread));
|
||||
"mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
|
||||
@ -269,7 +309,6 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
|
||||
// Choose work per thread
|
||||
int work_per_thread = 4;
|
||||
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user