Instantiate only vectorized contiguous kernels

This commit is contained in:
Angelos Katharopoulos
2025-07-30 13:04:02 -07:00
parent 46b01279f8
commit 25ad6ab443

View File

@@ -198,6 +198,7 @@ struct FusedKernelBuilder {
}
os += " }\n";
// Store the output to global memory
for (const auto& x : outputs) {
os += fmt::format(
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
@@ -249,11 +250,15 @@ void Compiled::eval_gpu(
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names;
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>", lib_name(), wpt));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>", lib_name(), wpt));
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));