mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Instantiate only vectorized contiguous kernels
This commit is contained in:
@@ -198,6 +198,7 @@ struct FusedKernelBuilder {
|
|||||||
}
|
}
|
||||||
os += " }\n";
|
os += " }\n";
|
||||||
|
|
||||||
|
// Store the output to global memory
|
||||||
for (const auto& x : outputs) {
|
for (const auto& x : outputs) {
|
||||||
os += fmt::format(
|
os += fmt::format(
|
||||||
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
|
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
|
||||||
@@ -249,11 +250,15 @@ void Compiled::eval_gpu(
|
|||||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||||
// Build kernel names.
|
// Build kernel names.
|
||||||
std::vector<std::string> kernel_names;
|
std::vector<std::string> kernel_names;
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
work_per_thread));
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
work_per_thread));
|
||||||
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
|
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
|
||||||
kernel_names.push_back(fmt::format(
|
|
||||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>", lib_name(), wpt));
|
|
||||||
kernel_names.push_back(fmt::format(
|
|
||||||
"mlx::core::cu::{}_contiguous<int64_t, {}>", lib_name(), wpt));
|
|
||||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||||
kernel_names.push_back(fmt::format(
|
kernel_names.push_back(fmt::format(
|
||||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
|
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
|
||||||
|
|||||||
Reference in New Issue
Block a user