diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 6eda2533f4..7f859b91a0 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -104,10 +104,41 @@ struct FusedKernelBuilder { " }\n"; } + // Vectorized read loop + if (contiguous) { + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + if (is_scalar(x) || is_constant(i)) { + continue; + } + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_cuda_type(x.dtype()); + os += fmt::format( + " auto vec_{0} = load_vector({0} + index, 0, size - index, 0);\n", + xname, + type); + } + } + + // Create some space for the outputs + for (const auto& x : outputs) { + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_cuda_type(x.dtype()); + os += fmt::format( + " AlignedVector<{}, work_per_thread> vec_{};\n", type, xname); + } + // Work loop - os += - "\n" - " for (int i = 0; i < work_per_thread && index < size; i++) {\n"; + if (!contiguous) { + os += + "\n" + " for (int i = 0; i < work_per_thread && index < size; i++) {\n"; + } else { + os += + "\n" + " #pragma unroll\n" + " for (int i = 0; i < work_per_thread; i++) {\n"; + } // Read inputs. for (size_t i = 0; i < inputs.size(); ++i) { @@ -122,7 +153,7 @@ struct FusedKernelBuilder { } else if (is_scalar(x)) { value = fmt::format("{}[0]", xname); } else if (contiguous) { - value = fmt::format("{}[index]", xname); + value = fmt::format("vec_{}[i]", xname); } else { value = fmt::format("{}[{}_idx]", xname, xname); } @@ -150,25 +181,30 @@ struct FusedKernelBuilder { // Write output. for (const auto& x : outputs) { - os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); + os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x)); } // End of work loop - os += - "\n" - " index++;\n"; if (!contiguous) { + os += "\n"; for (size_t i = 0; i < inputs.size(); ++i) { const auto& x = inputs[i]; const std::string& xname = namer.get_name(x); if (is_scalar(x) || is_constant(i)) { continue; } - os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n"; + os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname); } } os += " }\n"; + // Store the output to global memory + for (const auto& x : outputs) { + os += fmt::format( + " store_vector({0} + index, 0, vec_{0}, size - index);\n", + namer.get_name(x)); + } + os += "}\n"; } }; @@ -192,6 +228,15 @@ void Compiled::eval_gpu( nvtx3::scoped_range r("Compiled::eval_gpu"); auto& s = stream(); + // Determine the work per thread for the vectorized reads/writes. We take it + // as 16 over the max itemsize for the outputs. Another heuristic could be + // over the max itemsize of all arrays. + int max_size = 1; + for (const auto& x : outputs) { + max_size = (max_size > x.itemsize()) ? max_size : x.itemsize(); + } + int work_per_thread = 16 / max_size; + cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() { // Build source code. cu::FusedKernelBuilder builder{ @@ -205,28 +250,23 @@ void Compiled::eval_gpu( builder.os += "\n} // namespace mlx::core::cu\n"; // Build kernel names. std::vector kernel_names; - for (auto work_per_thread : std::array{1, 4}) { - kernel_names.push_back(fmt::format( - "mlx::core::cu::{}_contiguous", - lib_name(), - work_per_thread)); - kernel_names.push_back(fmt::format( - "mlx::core::cu::{}_contiguous", - lib_name(), - work_per_thread)); + kernel_names.push_back(fmt::format( + "mlx::core::cu::{}_contiguous", + lib_name(), + work_per_thread)); + kernel_names.push_back(fmt::format( + "mlx::core::cu::{}_contiguous", + lib_name(), + work_per_thread)); + for (auto wpt : std::array{1, work_per_thread}) { for (int i = 1; i <= MAX_NDIM; ++i) { kernel_names.push_back(fmt::format( - "mlx::core::cu::{}_strided<{}, uint32_t, {}>", - lib_name(), - i, - work_per_thread)); + "mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt)); kernel_names.push_back(fmt::format( - "mlx::core::cu::{}_strided<{}, int64_t, {}>", - lib_name(), - i, - work_per_thread)); + "mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt)); } } + return std::make_pair(std::move(builder.os), std::move(kernel_names)); }); @@ -269,7 +309,6 @@ void Compiled::eval_gpu( } // Choose work per thread - int work_per_thread = 4; if (!contiguous && shape.back() % work_per_thread != 0) { work_per_thread = 1; }