From 32972a59249931f79ea5c7d551ec65a4ca1c173c Mon Sep 17 00:00:00 2001 From: xnorai <147757538+xnorai@users.noreply.github.com> Date: Thu, 24 Oct 2024 08:54:51 -0700 Subject: [PATCH] C++20 compatibility for fmt (#1519) * C++20 compatibility for fmt * Address review feedback * Remove stray string * Add newlines back --- mlx/backend/metal/indexing.cpp | 2 +- mlx/backend/metal/jit_kernels.cpp | 5 +++-- mlx/backend/metal/kernels.h | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 37c511b39..7820b0272 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -262,7 +262,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { break; } if (reduce_type_ != Scatter::None) { - op_type = fmt::format(op_type, out_type_str); + op_type = fmt::format(fmt::runtime(op_type), out_type_str); } auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 430ff65af..ff567f0c0 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -445,8 +445,9 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel( kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() << fmt::format( - axbpy ? steel_gemm_splitk_accum_axbpy_kernels - : steel_gemm_splitk_accum_kernels, + fmt::runtime( + axbpy ? steel_gemm_splitk_accum_axbpy_kernels + : steel_gemm_splitk_accum_kernels), "name"_a = lib_name, "atype"_a = get_type_string(in.dtype()), "otype"_a = get_type_string(out.dtype())); diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 2f861373d..d8a258cb8 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -209,10 +209,10 @@ get_template_definition(std::string name, std::string func, Args... args) { }; (add_arg(args), ...); s << ">"; - std::string base_string = R"( -template [[host_name("{0}")]] [[kernel]] decltype({1}) {1}; - )"; - return fmt::format(base_string, name, s.str()); + return fmt::format( + "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n", + name, + s.str()); } } // namespace mlx::core