[CUDA] Vectorize generated kernels (#2444)

This commit is contained in:
Angelos Katharopoulos 2025-07-31 18:18:57 -07:00 committed by GitHub
parent b26d88591c
commit 86258f292f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -104,10 +104,41 @@ struct FusedKernelBuilder {
" }\n"; " }\n";
} }
// Vectorized read loop
if (contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
if (is_scalar(x) || is_constant(i)) {
continue;
}
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
os += fmt::format(
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
xname,
type);
}
}
// Create some space for the outputs
for (const auto& x : outputs) {
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
os += fmt::format(
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
}
// Work loop // Work loop
if (!contiguous) {
os += os +=
"\n" "\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n"; " for (int i = 0; i < work_per_thread && index < size; i++) {\n";
} else {
os +=
"\n"
" #pragma unroll\n"
" for (int i = 0; i < work_per_thread; i++) {\n";
}
// Read inputs. // Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
@ -122,7 +153,7 @@ struct FusedKernelBuilder {
} else if (is_scalar(x)) { } else if (is_scalar(x)) {
value = fmt::format("{}[0]", xname); value = fmt::format("{}[0]", xname);
} else if (contiguous) { } else if (contiguous) {
value = fmt::format("{}[index]", xname); value = fmt::format("vec_{}[i]", xname);
} else { } else {
value = fmt::format("{}[{}_idx]", xname, xname); value = fmt::format("{}[{}_idx]", xname, xname);
} }
@ -150,25 +181,30 @@ struct FusedKernelBuilder {
// Write output. // Write output.
for (const auto& x : outputs) { for (const auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
} }
// End of work loop // End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) { if (!contiguous) {
os += "\n";
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i]; const auto& x = inputs[i];
const std::string& xname = namer.get_name(x); const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) { if (is_scalar(x) || is_constant(i)) {
continue; continue;
} }
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n"; os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
} }
} }
os += " }\n"; os += " }\n";
// Store the output to global memory
for (const auto& x : outputs) {
os += fmt::format(
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
namer.get_name(x));
}
os += "}\n"; os += "}\n";
} }
}; };
@ -192,6 +228,15 @@ void Compiled::eval_gpu(
nvtx3::scoped_range r("Compiled::eval_gpu"); nvtx3::scoped_range r("Compiled::eval_gpu");
auto& s = stream(); auto& s = stream();
// Determine the work per thread for the vectorized reads/writes. We take it
// as 16 bytes divided by the largest itemsize among the outputs. Another
// heuristic could be to use the largest itemsize over all arrays instead.
int max_size = 1;
for (const auto& x : outputs) {
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
}
int work_per_thread = 16 / max_size;
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() { cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
// Build source code. // Build source code.
cu::FusedKernelBuilder builder{ cu::FusedKernelBuilder builder{
@ -205,7 +250,6 @@ void Compiled::eval_gpu(
builder.os += "\n} // namespace mlx::core::cu\n"; builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names. // Build kernel names.
std::vector<std::string> kernel_names; std::vector<std::string> kernel_names;
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
kernel_names.push_back(fmt::format( kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>", "mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(), lib_name(),
@ -214,19 +258,15 @@ void Compiled::eval_gpu(
"mlx::core::cu::{}_contiguous<int64_t, {}>", "mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(), lib_name(),
work_per_thread)); work_per_thread));
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
for (int i = 1; i <= MAX_NDIM; ++i) { for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format( kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", "mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
lib_name(),
i,
work_per_thread));
kernel_names.push_back(fmt::format( kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>", "mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
lib_name(),
i,
work_per_thread));
} }
} }
return std::make_pair(std::move(builder.os), std::move(kernel_names)); return std::make_pair(std::move(builder.os), std::move(kernel_names));
}); });
@ -269,7 +309,6 @@ void Compiled::eval_gpu(
} }
// Choose work per thread // Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) { if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1; work_per_thread = 1;
} }