From 0a41393dba2d5e285b50dde5a8f2d768526e288a Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 17 Jan 2025 11:24:16 -0800
Subject: [PATCH] Replace fmt::format with std::format

---
 CMakeLists.txt                      |  4 --
 mlx/backend/common/compiled_cpu.cpp |  5 +--
 mlx/backend/common/jit_compiler.cpp | 16 +++----
 mlx/backend/metal/compiled.cpp      | 68 ++++++++++++++---------------
 mlx/backend/metal/indexing.cpp      | 16 +++----
 mlx/backend/metal/jit_kernels.cpp   | 32 +++++++-------
 mlx/backend/metal/kernels.h         |  4 +-
 mlx/backend/metal/primitives.cpp    |  2 +-
 8 files changed, 69 insertions(+), 78 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5963c8442..85717744d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -223,10 +223,6 @@ target_include_directories(
   mlx PUBLIC $ $)
 
-# FetchContent_Declare( fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git
-# GIT_TAG 10.2.1 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(fmt)
-# target_link_libraries(mlx PRIVATE $)
-
 if(MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(
diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp
index e5c0156c8..c63bda73e 100644
--- a/mlx/backend/common/compiled_cpu.cpp
+++ b/mlx/backend/common/compiled_cpu.cpp
@@ -2,13 +2,12 @@
 #include
 #include
+#include
 #include
 #include
 #include
 #include
-#include
-
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/compiled_preamble.h"
 #include "mlx/backend/common/jit_compiler.h"
@@ -111,7 +110,7 @@ void* compile(
     JitCompiler::exec(JitCompiler::build_command(
         output_dir, source_file_name, shared_lib_name));
   } catch (const std::exception& error) {
-    throw std::runtime_error(fmt::format(
+    throw std::runtime_error(std::format(
         "[Compile::eval_cpu] Failed to compile function {0}: {1}",
         kernel_name,
         error.what()));
diff --git a/mlx/backend/common/jit_compiler.cpp b/mlx/backend/common/jit_compiler.cpp
index 34d57138c..03d7caf7a 100644
--- a/mlx/backend/common/jit_compiler.cpp
+++ b/mlx/backend/common/jit_compiler.cpp
@@ -5,7 +5,7 @@
 #include
 #include
-#include
+#include
 
 namespace mlx::core {
 
@@ -33,7 +33,7 @@ struct VisualStudioInfo {
     arch = "x64";
 #endif
     // Get path of Visual Studio.
-    std::string vs_path = JitCompiler::exec(fmt::format(
+    std::string vs_path = JitCompiler::exec(std::format(
         "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
         " -property installationPath",
         std::getenv("ProgramFiles(x86)")));
@@ -41,7 +41,7 @@ struct VisualStudioInfo {
       throw std::runtime_error("Can not find Visual Studio.");
     }
     // Read the envs from vcvarsall.
- std::string envs = JitCompiler::exec(fmt::format( + std::string envs = JitCompiler::exec(std::format( "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", vs_path, arch)); @@ -55,7 +55,7 @@ struct VisualStudioInfo { if (name == "LIB") { libpaths = str_split(value, ';'); } else if (name == "VCToolsInstallDir") { - cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); + cl_exe = std::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); } } } @@ -81,9 +81,9 @@ std::string JitCompiler::build_command( const VisualStudioInfo& info = GetVisualStudioInfo(); std::string libpaths; for (const std::string& lib : info.libpaths) { - libpaths += fmt::format(" /libpath:\"{0}\"", lib); + libpaths += std::format(" /libpath:\"{0}\"", lib); } - return fmt::format( + return std::format( "\"" "cd /D \"{0}\" && " "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" " @@ -95,7 +95,7 @@ std::string JitCompiler::build_command( shared_lib_name, libpaths); #else - return fmt::format( + return std::format( "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1", (dir / source_file_name).string(), (dir / shared_lib_name).string()); @@ -139,7 +139,7 @@ std::string JitCompiler::exec(const std::string& cmd) { int code = WEXITSTATUS(status); #endif if (code != 0) { - throw std::runtime_error(fmt::format( + throw std::runtime_error(std::format( "Failed to execute command with return code {0}: \"{1}\", " "the output is: {2}", code, diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index e8e5e77e8..27a5640ac 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -1,5 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include +#include #include #include "mlx/backend/common/compiled.h" @@ -11,8 +11,6 @@ #include "mlx/primitives.h" #include "mlx/utils.h" -using namespace fmt::literals; - namespace mlx::core { inline void build_kernel( @@ -41,7 +39,7 @@ inline void build_kernel( int cnt = 0; // Start the kernel - os += fmt::format( + os += std::format( "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name); // Add the input arguments @@ -57,7 +55,7 @@ inline void build_kernel( if (!is_scalar(x) && !contiguous) { add_indices = true; } - os += fmt::format( + os += std::format( " device const {0}* {1} [[buffer({2})]],\n", get_type_string(x.dtype()), xname, @@ -65,13 +63,13 @@ inline void build_kernel( } if (add_indices) { - os += fmt::format( + os += std::format( " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++); } // Add the output arguments for (auto& x : outputs) { - os += fmt::format( + os += std::format( " device {0}* {1} [[buffer({2})]],\n", get_type_string(x.dtype()), namer.get_name(x), @@ -79,13 +77,13 @@ inline void build_kernel( } // Add output strides and shape to extract the indices. 
if (!contiguous) { - os += fmt::format( + os += std::format( " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); - os += fmt::format( + os += std::format( " constant const int* output_shape [[buffer({0})]],\n", cnt++); } if (dynamic_dims) { - os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); + os += std::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); } // The thread index in the whole grid @@ -98,15 +96,15 @@ inline void build_kernel( // a third grid dimension os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; } else if (work_per_thread > 1) { - os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); - os += fmt::format( + os += std::format(" constexpr int N_ = {0};\n", work_per_thread); + os += std::format( " int xshape = output_shape[{0}];\n", dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)); - os += fmt::format( + os += std::format( " {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n", idx_type); } else { - os += fmt::format( + os += std::format( " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", idx_type); } @@ -121,16 +119,16 @@ inline void build_kernel( auto type_str = get_type_string(x.dtype()); std::ostringstream ss; print_constant(ss, x); - os += fmt::format( + os += std::format( " auto tmp_{0} = static_cast<{1}>({2});\n", xname, get_type_string(x.dtype()), ss.str()); } else if (is_scalar(x)) { - os += fmt::format( + os += std::format( " {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname); } else if (contiguous) { - os += fmt::format( + os += std::format( " {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname); } else { nc_inputs.push_back(x); @@ -140,30 +138,30 @@ inline void build_kernel( // Initialize the indices for non-contiguous inputs for (int i = 0; i < nc_inputs.size(); ++i) { auto& xname = namer.get_name(nc_inputs[i]); - os += fmt::format(" {0} index_{1} = ", idx_type, xname); + os += std::format(" {0} index_{1} = ", idx_type, xname); if (ndim == 1) { int offset = i * ndim; os += - fmt::format("elem_to_loc_1(pos.x, in_strides[{0}]);\n", offset); + std::format("elem_to_loc_1(pos.x, in_strides[{0}]);\n", offset); } else if (ndim == 2) { int offset = i * ndim; - os += fmt::format( + os += std::format( "elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n", idx_type, offset); } else if (ndim == 3) { int offset = i * ndim; - os += fmt::format( + os += std::format( "elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset); } else if (!dynamic_dims) { int offset = (i + 1) * ndim; - os += fmt::format( + os += std::format( "N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n", idx_type, offset - 1, offset - 2); } else { - os += fmt::format( + os += std::format( "N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n", idx_type, i); @@ -175,18 +173,18 @@ inline void build_kernel( if (dynamic_dims) { os += " for (int d = ndim - 3; d >= 0; --d) {\n"; } else { - os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3); + os += std::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3); } os += " uint l = zpos % output_shape[d];\n"; for (int i = 0; i < nc_inputs.size(); ++i) { auto& xname = namer.get_name(nc_inputs[i]); - os += fmt::format(" index_{0} += ", xname); + os += std::format(" index_{0} += ", xname); if (dynamic_dims) { os += - fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i); + std::format("l * {0}(in_strides[{1} * ndim + d]);\n", 
idx_type, i); } else { os += - fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim); + std::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim); } } os += " zpos /= output_shape[d];\n }\n"; @@ -202,16 +200,16 @@ inline void build_kernel( for (int i = 0; i < nc_inputs.size(); ++i) { auto& x = nc_inputs[i]; auto& xname = namer.get_name(x); - os += fmt::format( + os += std::format( " {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname); } // Actually write the computation for (auto& x : tape) { - os += fmt::format( + os += std::format( " {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x)); if (is_static_cast(x.primitive())) { - os += fmt::format( + os += std::format( "static_cast<{0}>(tmp_{1});\n", get_type_string(x.dtype()), namer.get_name(x.inputs()[0])); @@ -221,15 +219,15 @@ inline void build_kernel( os += ss.str(); os += "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { - os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); + os += std::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); } - os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back())); + os += std::format("tmp_{0});\n", namer.get_name(x.inputs().back())); } } // Write the outputs from tmps for (auto& x : outputs) { - os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); + os += std::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); } // Increment indices and close per thread loop if (work_per_thread > 1) { @@ -237,10 +235,10 @@ inline void build_kernel( auto& x = nc_inputs[i]; auto& xname = namer.get_name(x); if (!dynamic_dims) { - os += fmt::format( + os += std::format( " index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1); } else { - os += fmt::format( + os += std::format( " index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i); } } diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index b27765a40..f9823ce3e 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -1,5 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include +#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/copy.h" @@ -20,9 +20,9 @@ std::pair make_index_args( std::ostringstream idx_args; std::ostringstream idx_arr; for (int i = 0; i < nidx; ++i) { - idx_args << fmt::format( + idx_args << std::format( "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i); - idx_arr << fmt::format("idx{0}", i); + idx_arr << std::format("idx{0}", i); if (i < nidx - 1) { idx_args << "\n"; idx_arr << ","; @@ -59,7 +59,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { bool large = large_index || large_src || large_out; std::string idx_type_name = nidx ? 
type_to_name(inputs[1]) : ""; - std::string kernel_name = fmt::format( + std::string kernel_name = std::format( "gather{0}{1}_{2}_{3}_{4}", type_to_name(out), idx_type_name, @@ -77,7 +77,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); // Index dimension specializations - kernel_source += fmt::format( + kernel_source += std::format( gather_kernels, type_to_name(out) + idx_type_name, out_type_str, @@ -238,7 +238,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { bool large_idx = nidx && (inputs[1].size() > INT32_MAX); bool large_upd = upd.size() > INT32_MAX; bool large = large_out || large_idx || large_upd; - std::string kernel_name = fmt::format( + std::string kernel_name = std::format( "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}", type_to_name(out), idx_type_name, @@ -275,11 +275,11 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { break; } if (reduce_type_ != Scatter::None) { - op_type = fmt::format(fmt::runtime(op_type), out_type_str); + op_type = std::vformat(op_type, std::make_format_args(out_type_str)); } auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); - kernel_source += fmt::format( + kernel_source += std::format( scatter_kernels, type_to_name(out) + idx_type_name + "_" + op_name, out_type_str, diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 78560bb2a..3bba92056 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -9,8 +9,6 @@ #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" -using namespace fmt::literals; - namespace mlx::core { std::string op_name(const array& arr) { @@ -26,7 +24,7 @@ MTL::ComputePipelineState* get_arange_kernel( auto lib = d.get_library(kernel_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::arange() - << fmt::format( + << std::format( arange_kernels, kernel_name, get_type_string(out.dtype())); @@ -259,7 +257,7 @@ MTL::ComputePipelineState* get_softmax_kernel( auto lib = d.get_library(lib_name, [&] { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::softmax() - << fmt::format( + << std::format( softmax_kernels, lib_name, get_type_string(out.dtype()), @@ -445,7 +443,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel( std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_fused() - << fmt::format( + << std::format( steel_gemm_fused_kernels, "name"_a = lib_name, "itype"_a = get_type_string(out.dtype()), @@ -480,7 +478,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() - << fmt::format( + << std::format( steel_gemm_splitk_kernels, "name"_a = lib_name, "itype"_a = get_type_string(in.dtype()), @@ -510,13 +508,13 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel( std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() - << fmt::format( - fmt::runtime( - axbpy ? steel_gemm_splitk_accum_axbpy_kernels - : steel_gemm_splitk_accum_kernels), - "name"_a = lib_name, - "atype"_a = get_type_string(in.dtype()), - "otype"_a = get_type_string(out.dtype())); + << std::vformat( + axbpy ? 
steel_gemm_splitk_accum_axbpy_kernels + : steel_gemm_splitk_accum_kernels, + std::make_format_args( + "name"_a = lib_name, + "atype"_a = get_type_string(in.dtype()), + "otype"_a = get_type_string(out.dtype()))); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); @@ -547,7 +545,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_masked() - << fmt::format( + << std::format( steel_gemm_masked_kernels, "name"_a = lib_name, "itype"_a = get_type_string(out.dtype()), @@ -590,7 +588,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel( auto op_mask_type = mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; kernel_source << metal::utils() << metal::gemv_masked() - << fmt::format( + << std::format( gemv_masked_kernel, "name"_a = lib_name, "itype"_a = get_type_string(out.dtype()), @@ -624,7 +622,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel( auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv() - << fmt::format( + << std::format( steel_conv_kernels, "name"_a = lib_name, "itype"_a = get_type_string(out.dtype()), @@ -654,7 +652,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel( std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv_general() - << fmt::format( + << std::format( steel_conv_general_kernels, "name"_a = lib_name, "itype"_a = get_type_string(out.dtype()), diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index dd6213754..7456ef84e 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -1,6 +1,6 @@ // Copyright © 2024 Apple Inc. -#include +#include #include "mlx/array.h" #include "mlx/backend/metal/device.h" @@ -218,7 +218,7 @@ get_template_definition(std::string name, std::string func, Args... args) { }; (add_arg(args), ...); s << ">"; - return fmt::format( + return std::format( "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n", name, s.str()); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 47190daf3..f5d391331 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -66,7 +66,7 @@ array compute_dynamic_offset( auto dtype = indices.dtype(); std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); auto lib = d.get_library(lib_name, [dtype]() { - return fmt::format( + return std::format( R"( [[kernel]] void compute_dynamic_offset_{0}( constant const {1}* indices [[buffer(0)]],
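
The migrated call sites fall into two groups: format strings that are literals, which std::format checks at compile time, and format strings only known at runtime, such as op_type in Scatter::eval_gpu, which go through std::vformat rather than fmt::runtime. The snippet below is a minimal standalone sketch of both patterns, not code from this patch; kernel_name, op_type, and out_type_str hold made-up illustrative values, and it assumes a compiler and standard library with C++20 <format> support.

// Sketch only: the two call styles the fmt -> std migration relies on.
#include <format>
#include <iostream>
#include <string>

int main() {
  // 1) Literal format string with positional placeholders: the direct
  //    replacement for most fmt::format call sites; checked at compile time.
  std::string kernel_name = "gather_example_kernel"; // illustrative value
  std::string msg = std::format(
      "[Compile::eval_cpu] Failed to compile function {0}: {1}",
      kernel_name,
      "exit code 1");

  // 2) Format string chosen at runtime (fmt::runtime(...) in fmt): std::format
  //    cannot accept it, so std::vformat is used with type-erased arguments,
  //    mirroring the op_type handling in Scatter::eval_gpu.
  std::string op_type = "Sum<{0}>"; // illustrative runtime-selected template
  std::string out_type_str = "float";
  op_type = std::vformat(op_type, std::make_format_args(out_type_str));

  std::cout << msg << "\n" << op_type << "\n"; // second line prints Sum<float>
}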
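
The steel GEMM, GEMV, and convolution kernel sources above are filled in with named arguments ("name"_a, "itype"_a, and so on), which are a {fmt} extension; std::format and std::vformat only understand positional or automatically numbered fields. The helper below is a hedged sketch of one way to expand named placeholders without {fmt}; format_named and the kernel template text are hypothetical illustrations, not part of MLX or of this patch.

// Sketch only: expand "{key}" fields by plain substitution, since <format>
// has no named-argument support.
#include <iostream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

std::string format_named(
    std::string templ,
    const std::vector<std::pair<std::string_view, std::string_view>>& args) {
  for (const auto& [key, value] : args) {
    std::string field = "{" + std::string(key) + "}";
    // Replace every occurrence of the field with its value.
    for (size_t pos = templ.find(field); pos != std::string::npos;
         pos = templ.find(field, pos + value.size())) {
      templ.replace(pos, field.size(), value);
    }
  }
  return templ;
}

int main() {
  // Hypothetical template in the spirit of the steel GEMM kernel sources.
  std::string kernel_template =
      "template [[host_name(\"{name}\")]] [[kernel]] void gemm<{itype}, {otype}>();";
  std::cout << format_named(
                   kernel_template,
                   {{"name", "steel_gemm_fused_float32"},
                    {"itype", "float"},
                    {"otype", "float"}})
            << "\n";
}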