mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Instantiate only vectorized contiguous kernels
This commit is contained in:
@@ -198,6 +198,7 @@ struct FusedKernelBuilder {
|
||||
}
|
||||
os += " }\n";
|
||||
|
||||
// Store the output to global memory
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(
|
||||
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
|
||||
@@ -249,11 +250,15 @@ void Compiled::eval_gpu(
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names;
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>", lib_name(), wpt));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>", lib_name(), wpt));
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
|
||||
|
||||
Reference in New Issue
Block a user