diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 814e41968..7f859b91a 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -198,6 +198,7 @@ struct FusedKernelBuilder { } os += " }\n"; + // Store the output to global memory for (const auto& x : outputs) { os += fmt::format( " store_vector({0} + index, 0, vec_{0}, size - index);\n", @@ -249,11 +250,15 @@ void Compiled::eval_gpu( builder.os += "\n} // namespace mlx::core::cu\n"; // Build kernel names. std::vector kernel_names; + kernel_names.push_back(fmt::format( + "mlx::core::cu::{}_contiguous", + lib_name(), + work_per_thread)); + kernel_names.push_back(fmt::format( + "mlx::core::cu::{}_contiguous", + lib_name(), + work_per_thread)); for (auto wpt : std::array{1, work_per_thread}) { - kernel_names.push_back(fmt::format( - "mlx::core::cu::{}_contiguous", lib_name(), wpt)); - kernel_names.push_back(fmt::format( - "mlx::core::cu::{}_contiguous", lib_name(), wpt)); for (int i = 1; i <= MAX_NDIM; ++i) { kernel_names.push_back(fmt::format( "mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));