mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
[CUDA] Vectorize generated kernels (#2444)
This commit is contained in:
parent
b26d88591c
commit
86258f292f
@ -104,10 +104,41 @@ struct FusedKernelBuilder {
|
|||||||
" }\n";
|
" }\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Vectorized read loop
|
||||||
|
if (contiguous) {
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
std::string type = dtype_to_cuda_type(x.dtype());
|
||||||
|
os += fmt::format(
|
||||||
|
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
|
||||||
|
xname,
|
||||||
|
type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create some space for the outputs
|
||||||
|
for (const auto& x : outputs) {
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
std::string type = dtype_to_cuda_type(x.dtype());
|
||||||
|
os += fmt::format(
|
||||||
|
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
|
||||||
|
}
|
||||||
|
|
||||||
// Work loop
|
// Work loop
|
||||||
os +=
|
if (!contiguous) {
|
||||||
"\n"
|
os +=
|
||||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
"\n"
|
||||||
|
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||||
|
} else {
|
||||||
|
os +=
|
||||||
|
"\n"
|
||||||
|
" #pragma unroll\n"
|
||||||
|
" for (int i = 0; i < work_per_thread; i++) {\n";
|
||||||
|
}
|
||||||
|
|
||||||
// Read inputs.
|
// Read inputs.
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
@ -122,7 +153,7 @@ struct FusedKernelBuilder {
|
|||||||
} else if (is_scalar(x)) {
|
} else if (is_scalar(x)) {
|
||||||
value = fmt::format("{}[0]", xname);
|
value = fmt::format("{}[0]", xname);
|
||||||
} else if (contiguous) {
|
} else if (contiguous) {
|
||||||
value = fmt::format("{}[index]", xname);
|
value = fmt::format("vec_{}[i]", xname);
|
||||||
} else {
|
} else {
|
||||||
value = fmt::format("{}[{}_idx]", xname, xname);
|
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||||
}
|
}
|
||||||
@ -150,25 +181,30 @@ struct FusedKernelBuilder {
|
|||||||
|
|
||||||
// Write output.
|
// Write output.
|
||||||
for (const auto& x : outputs) {
|
for (const auto& x : outputs) {
|
||||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
// End of work loop
|
// End of work loop
|
||||||
os +=
|
|
||||||
"\n"
|
|
||||||
" index++;\n";
|
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
|
os += "\n";
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
const auto& x = inputs[i];
|
const auto& x = inputs[i];
|
||||||
const std::string& xname = namer.get_name(x);
|
const std::string& xname = namer.get_name(x);
|
||||||
if (is_scalar(x) || is_constant(i)) {
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
|
os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
os += " }\n";
|
os += " }\n";
|
||||||
|
|
||||||
|
// Store the output to global memory
|
||||||
|
for (const auto& x : outputs) {
|
||||||
|
os += fmt::format(
|
||||||
|
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
|
||||||
|
namer.get_name(x));
|
||||||
|
}
|
||||||
|
|
||||||
os += "}\n";
|
os += "}\n";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -192,6 +228,15 @@ void Compiled::eval_gpu(
|
|||||||
nvtx3::scoped_range r("Compiled::eval_gpu");
|
nvtx3::scoped_range r("Compiled::eval_gpu");
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
|
|
||||||
|
// Determine the work per thread for the vectorized reads/writes. We take it
|
||||||
|
// as 16 over the max itemsize for the outputs. Another heuristic could be
|
||||||
|
// over the max itemsize of all arrays.
|
||||||
|
int max_size = 1;
|
||||||
|
for (const auto& x : outputs) {
|
||||||
|
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
|
||||||
|
}
|
||||||
|
int work_per_thread = 16 / max_size;
|
||||||
|
|
||||||
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
||||||
// Build source code.
|
// Build source code.
|
||||||
cu::FusedKernelBuilder builder{
|
cu::FusedKernelBuilder builder{
|
||||||
@ -205,28 +250,23 @@ void Compiled::eval_gpu(
|
|||||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||||
// Build kernel names.
|
// Build kernel names.
|
||||||
std::vector<std::string> kernel_names;
|
std::vector<std::string> kernel_names;
|
||||||
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
kernel_names.push_back(fmt::format(
|
||||||
kernel_names.push_back(fmt::format(
|
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
lib_name(),
|
||||||
lib_name(),
|
work_per_thread));
|
||||||
work_per_thread));
|
kernel_names.push_back(fmt::format(
|
||||||
kernel_names.push_back(fmt::format(
|
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
lib_name(),
|
||||||
lib_name(),
|
work_per_thread));
|
||||||
work_per_thread));
|
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
|
||||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||||
kernel_names.push_back(fmt::format(
|
kernel_names.push_back(fmt::format(
|
||||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
|
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
|
||||||
lib_name(),
|
|
||||||
i,
|
|
||||||
work_per_thread));
|
|
||||||
kernel_names.push_back(fmt::format(
|
kernel_names.push_back(fmt::format(
|
||||||
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
|
"mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
|
||||||
lib_name(),
|
|
||||||
i,
|
|
||||||
work_per_thread));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -269,7 +309,6 @@ void Compiled::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Choose work per thread
|
// Choose work per thread
|
||||||
int work_per_thread = 4;
|
|
||||||
if (!contiguous && shape.back() % work_per_thread != 0) {
|
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||||
work_per_thread = 1;
|
work_per_thread = 1;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user