diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 37c511b39..7820b0272 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -262,7 +262,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { break; } if (reduce_type_ != Scatter::None) { - op_type = fmt::format(op_type, out_type_str); + op_type = fmt::format(fmt::runtime(op_type), out_type_str); } auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 430ff65af..ff567f0c0 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -445,8 +445,9 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel( kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() << fmt::format( - axbpy ? steel_gemm_splitk_accum_axbpy_kernels - : steel_gemm_splitk_accum_kernels, + fmt::runtime( + axbpy ? steel_gemm_splitk_accum_axbpy_kernels + : steel_gemm_splitk_accum_kernels), "name"_a = lib_name, "atype"_a = get_type_string(in.dtype()), "otype"_a = get_type_string(out.dtype())); diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 2f861373d..d8a258cb8 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -209,10 +209,10 @@ get_template_definition(std::string name, std::string func, Args... args) { }; (add_arg(args), ...); s << ">"; - std::string base_string = R"( -template [[host_name("{0}")]] [[kernel]] decltype({1}) {1}; - )"; - return fmt::format(base_string, name, s.str()); + return fmt::format( + "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n", + name, + s.str()); } } // namespace mlx::core