Replace fmt::format with std::format

This commit is contained in:
Angelos Katharopoulos 2025-01-17 11:24:16 -08:00
parent e300a01f4a
commit 0a41393dba
8 changed files with 69 additions and 78 deletions

View File

@@ -223,10 +223,6 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}> mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>) $<INSTALL_INTERFACE:include>)
# FetchContent_Declare( fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git
# GIT_TAG 10.2.1 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(fmt)
# target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS) if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.") message(STATUS "Building Python bindings.")
find_package( find_package(

View File

@@ -2,13 +2,12 @@
#include <dlfcn.h> #include <dlfcn.h>
#include <filesystem> #include <filesystem>
#include <format>
#include <fstream> #include <fstream>
#include <list> #include <list>
#include <mutex> #include <mutex>
#include <shared_mutex> #include <shared_mutex>
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h" #include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/jit_compiler.h" #include "mlx/backend/common/jit_compiler.h"
@@ -111,7 +110,7 @@ void* compile(
JitCompiler::exec(JitCompiler::build_command( JitCompiler::exec(JitCompiler::build_command(
output_dir, source_file_name, shared_lib_name)); output_dir, source_file_name, shared_lib_name));
} catch (const std::exception& error) { } catch (const std::exception& error) {
throw std::runtime_error(fmt::format( throw std::runtime_error(std::format(
"[Compile::eval_cpu] Failed to compile function {0}: {1}", "[Compile::eval_cpu] Failed to compile function {0}: {1}",
kernel_name, kernel_name,
error.what())); error.what()));

View File

@@ -5,7 +5,7 @@
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include <fmt/format.h> #include <format>
namespace mlx::core { namespace mlx::core {
@@ -33,7 +33,7 @@ struct VisualStudioInfo {
arch = "x64"; arch = "x64";
#endif #endif
// Get path of Visual Studio. // Get path of Visual Studio.
std::string vs_path = JitCompiler::exec(fmt::format( std::string vs_path = JitCompiler::exec(std::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath", " -property installationPath",
std::getenv("ProgramFiles(x86)"))); std::getenv("ProgramFiles(x86)")));
@@ -41,7 +41,7 @@ struct VisualStudioInfo {
throw std::runtime_error("Can not find Visual Studio."); throw std::runtime_error("Can not find Visual Studio.");
} }
// Read the envs from vcvarsall. // Read the envs from vcvarsall.
std::string envs = JitCompiler::exec(fmt::format( std::string envs = JitCompiler::exec(std::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
vs_path, vs_path,
arch)); arch));
@@ -55,7 +55,7 @@ struct VisualStudioInfo {
if (name == "LIB") { if (name == "LIB") {
libpaths = str_split(value, ';'); libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir") { } else if (name == "VCToolsInstallDir") {
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); cl_exe = std::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
} }
} }
} }
@@ -81,9 +81,9 @@ std::string JitCompiler::build_command(
const VisualStudioInfo& info = GetVisualStudioInfo(); const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths; std::string libpaths;
for (const std::string& lib : info.libpaths) { for (const std::string& lib : info.libpaths) {
libpaths += fmt::format(" /libpath:\"{0}\"", lib); libpaths += std::format(" /libpath:\"{0}\"", lib);
} }
return fmt::format( return std::format(
"\"" "\""
"cd /D \"{0}\" && " "cd /D \"{0}\" && "
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" " "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
@@ -95,7 +95,7 @@ std::string JitCompiler::build_command(
shared_lib_name, shared_lib_name,
libpaths); libpaths);
#else #else
return fmt::format( return std::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1", "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1",
(dir / source_file_name).string(), (dir / source_file_name).string(),
(dir / shared_lib_name).string()); (dir / shared_lib_name).string());
@@ -139,7 +139,7 @@ std::string JitCompiler::exec(const std::string& cmd) {
int code = WEXITSTATUS(status); int code = WEXITSTATUS(status);
#endif #endif
if (code != 0) { if (code != 0) {
throw std::runtime_error(fmt::format( throw std::runtime_error(std::format(
"Failed to execute command with return code {0}: \"{1}\", " "Failed to execute command with return code {0}: \"{1}\", "
"the output is: {2}", "the output is: {2}",
code, code,

View File

@@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h> #include <format>
#include <sstream> #include <sstream>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
@@ -11,8 +11,6 @@
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core { namespace mlx::core {
inline void build_kernel( inline void build_kernel(
@@ -41,7 +39,7 @@ inline void build_kernel(
int cnt = 0; int cnt = 0;
// Start the kernel // Start the kernel
os += fmt::format( os += std::format(
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name); "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
// Add the input arguments // Add the input arguments
@@ -57,7 +55,7 @@ inline void build_kernel(
if (!is_scalar(x) && !contiguous) { if (!is_scalar(x) && !contiguous) {
add_indices = true; add_indices = true;
} }
os += fmt::format( os += std::format(
" device const {0}* {1} [[buffer({2})]],\n", " device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()), get_type_string(x.dtype()),
xname, xname,
@@ -65,13 +63,13 @@ inline void build_kernel(
} }
if (add_indices) { if (add_indices) {
os += fmt::format( os += std::format(
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++); " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
} }
// Add the output arguments // Add the output arguments
for (auto& x : outputs) { for (auto& x : outputs) {
os += fmt::format( os += std::format(
" device {0}* {1} [[buffer({2})]],\n", " device {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()), get_type_string(x.dtype()),
namer.get_name(x), namer.get_name(x),
@@ -79,13 +77,13 @@ inline void build_kernel(
} }
// Add output strides and shape to extract the indices. // Add output strides and shape to extract the indices.
if (!contiguous) { if (!contiguous) {
os += fmt::format( os += std::format(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format( os += std::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++); " constant const int* output_shape [[buffer({0})]],\n", cnt++);
} }
if (dynamic_dims) { if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); os += std::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
} }
// The thread index in the whole grid // The thread index in the whole grid
@@ -98,15 +96,15 @@ inline void build_kernel(
// a third grid dimension // a third grid dimension
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
} else if (work_per_thread > 1) { } else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os += std::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format( os += std::format(
" int xshape = output_shape[{0}];\n", " int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)); dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
os += fmt::format( os += std::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n", " {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type); idx_type);
} else { } else {
os += fmt::format( os += std::format(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type); idx_type);
} }
@@ -121,16 +119,16 @@ inline void build_kernel(
auto type_str = get_type_string(x.dtype()); auto type_str = get_type_string(x.dtype());
std::ostringstream ss; std::ostringstream ss;
print_constant(ss, x); print_constant(ss, x);
os += fmt::format( os += std::format(
" auto tmp_{0} = static_cast<{1}>({2});\n", " auto tmp_{0} = static_cast<{1}>({2});\n",
xname, xname,
get_type_string(x.dtype()), get_type_string(x.dtype()),
ss.str()); ss.str());
} else if (is_scalar(x)) { } else if (is_scalar(x)) {
os += fmt::format( os += std::format(
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname); " {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
} else if (contiguous) { } else if (contiguous) {
os += fmt::format( os += std::format(
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname); " {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
} else { } else {
nc_inputs.push_back(x); nc_inputs.push_back(x);
@@ -140,30 +138,30 @@ inline void build_kernel(
// Initialize the indices for non-contiguous inputs // Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]); auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname); os += std::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) { if (ndim == 1) {
int offset = i * ndim; int offset = i * ndim;
os += os +=
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset); std::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
} else if (ndim == 2) { } else if (ndim == 2) {
int offset = i * ndim; int offset = i * ndim;
os += fmt::format( os += std::format(
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n", "elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type, idx_type,
offset); offset);
} else if (ndim == 3) { } else if (ndim == 3) {
int offset = i * ndim; int offset = i * ndim;
os += fmt::format( os += std::format(
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset); "elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
} else if (!dynamic_dims) { } else if (!dynamic_dims) {
int offset = (i + 1) * ndim; int offset = (i + 1) * ndim;
os += fmt::format( os += std::format(
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n", "N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
idx_type, idx_type,
offset - 1, offset - 1,
offset - 2); offset - 2);
} else { } else {
os += fmt::format( os += std::format(
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n", "N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
idx_type, idx_type,
i); i);
@@ -175,18 +173,18 @@ inline void build_kernel(
if (dynamic_dims) { if (dynamic_dims) {
os += " for (int d = ndim - 3; d >= 0; --d) {\n"; os += " for (int d = ndim - 3; d >= 0; --d) {\n";
} else { } else {
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3); os += std::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
} }
os += " uint l = zpos % output_shape[d];\n"; os += " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]); auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" index_{0} += ", xname); os += std::format(" index_{0} += ", xname);
if (dynamic_dims) { if (dynamic_dims) {
os += os +=
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i); std::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
} else { } else {
os += os +=
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim); std::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
} }
} }
os += " zpos /= output_shape[d];\n }\n"; os += " zpos /= output_shape[d];\n }\n";
@@ -202,16 +200,16 @@ inline void build_kernel(
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i]; auto& x = nc_inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
os += fmt::format( os += std::format(
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname); " {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
} }
// Actually write the computation // Actually write the computation
for (auto& x : tape) { for (auto& x : tape) {
os += fmt::format( os += std::format(
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x)); " {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
if (is_static_cast(x.primitive())) { if (is_static_cast(x.primitive())) {
os += fmt::format( os += std::format(
"static_cast<{0}>(tmp_{1});\n", "static_cast<{0}>(tmp_{1});\n",
get_type_string(x.dtype()), get_type_string(x.dtype()),
namer.get_name(x.inputs()[0])); namer.get_name(x.inputs()[0]));
@@ -221,15 +219,15 @@ inline void build_kernel(
os += ss.str(); os += ss.str();
os += "()("; os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); os += std::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
} }
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back())); os += std::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
} }
} }
// Write the outputs from tmps // Write the outputs from tmps
for (auto& x : outputs) { for (auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); os += std::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
} }
// Increment indices and close per thread loop // Increment indices and close per thread loop
if (work_per_thread > 1) { if (work_per_thread > 1) {
@@ -237,10 +235,10 @@ inline void build_kernel(
auto& x = nc_inputs[i]; auto& x = nc_inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
if (!dynamic_dims) { if (!dynamic_dims) {
os += fmt::format( os += std::format(
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1); " index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
} else { } else {
os += fmt::format( os += std::format(
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i); " index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
} }
} }

View File

@@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h> #include <format>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
@@ -20,9 +20,9 @@ std::pair<std::string, std::string> make_index_args(
std::ostringstream idx_args; std::ostringstream idx_args;
std::ostringstream idx_arr; std::ostringstream idx_arr;
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
idx_args << fmt::format( idx_args << std::format(
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i); "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
idx_arr << fmt::format("idx{0}", i); idx_arr << std::format("idx{0}", i);
if (i < nidx - 1) { if (i < nidx - 1) {
idx_args << "\n"; idx_args << "\n";
idx_arr << ","; idx_arr << ",";
@@ -59,7 +59,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
bool large = large_index || large_src || large_out; bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string kernel_name = fmt::format( std::string kernel_name = std::format(
"gather{0}{1}_{2}_{3}_{4}", "gather{0}{1}_{2}_{3}_{4}",
type_to_name(out), type_to_name(out),
idx_type_name, idx_type_name,
@@ -77,7 +77,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations // Index dimension specializations
kernel_source += fmt::format( kernel_source += std::format(
gather_kernels, gather_kernels,
type_to_name(out) + idx_type_name, type_to_name(out) + idx_type_name,
out_type_str, out_type_str,
@@ -238,7 +238,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
bool large_idx = nidx && (inputs[1].size() > INT32_MAX); bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
bool large_upd = upd.size() > INT32_MAX; bool large_upd = upd.size() > INT32_MAX;
bool large = large_out || large_idx || large_upd; bool large = large_out || large_idx || large_upd;
std::string kernel_name = fmt::format( std::string kernel_name = std::format(
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}", "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
type_to_name(out), type_to_name(out),
idx_type_name, idx_type_name,
@@ -275,11 +275,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break; break;
} }
if (reduce_type_ != Scatter::None) { if (reduce_type_ != Scatter::None) {
op_type = fmt::format(fmt::runtime(op_type), out_type_str); op_type = std::vformat(op_type, std::make_format_args(out_type_str));
} }
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source += fmt::format( kernel_source += std::format(
scatter_kernels, scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name, type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str, out_type_str,

View File

@@ -9,8 +9,6 @@
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
using namespace fmt::literals;
namespace mlx::core { namespace mlx::core {
std::string op_name(const array& arr) { std::string op_name(const array& arr) {
@@ -26,7 +24,7 @@ MTL::ComputePipelineState* get_arange_kernel(
auto lib = d.get_library(kernel_name, [&]() { auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::arange() kernel_source << metal::utils() << metal::arange()
<< fmt::format( << std::format(
arange_kernels, arange_kernels,
kernel_name, kernel_name,
get_type_string(out.dtype())); get_type_string(out.dtype()));
@@ -259,7 +257,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
auto lib = d.get_library(lib_name, [&] { auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax() kernel_source << metal::utils() << metal::softmax()
<< fmt::format( << std::format(
softmax_kernels, softmax_kernels,
lib_name, lib_name,
get_type_string(out.dtype()), get_type_string(out.dtype()),
@@ -445,7 +443,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused() << metal::steel_gemm_fused()
<< fmt::format( << std::format(
steel_gemm_fused_kernels, steel_gemm_fused_kernels,
"name"_a = lib_name, "name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()), "itype"_a = get_type_string(out.dtype()),
@@ -480,7 +478,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk() << metal::steel_gemm_splitk()
<< fmt::format( << std::format(
steel_gemm_splitk_kernels, steel_gemm_splitk_kernels,
"name"_a = lib_name, "name"_a = lib_name,
"itype"_a = get_type_string(in.dtype()), "itype"_a = get_type_string(in.dtype()),
@@ -510,13 +508,13 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk() << metal::steel_gemm_splitk()
<< fmt::format( << std::vformat(
fmt::runtime(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels), : steel_gemm_splitk_accum_kernels,
std::make_format_args(
"name"_a = lib_name, "name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()), "atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype())); "otype"_a = get_type_string(out.dtype())));
return kernel_source.str(); return kernel_source.str();
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
@@ -547,7 +545,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemm() kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_masked() << metal::steel_gemm_masked()
<< fmt::format( << std::format(
steel_gemm_masked_kernels, steel_gemm_masked_kernels,
"name"_a = lib_name, "name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()), "itype"_a = get_type_string(out.dtype()),
@@ -590,7 +588,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
auto op_mask_type = auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemv_masked() kernel_source << metal::utils() << metal::gemv_masked()
<< fmt::format( << std::format(
gemv_masked_kernel, gemv_masked_kernel,
"name"_a = lib_name, "name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()), "itype"_a = get_type_string(out.dtype()),
@@ -624,7 +622,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv() kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format( << std::format(
steel_conv_kernels, steel_conv_kernels,
"name"_a = lib_name, "name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()), "itype"_a = get_type_string(out.dtype()),
@@ -654,7 +652,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general() << metal::steel_conv_general()
<< fmt::format( << std::format(
steel_conv_general_kernels, steel_conv_general_kernels,
"name"_a = lib_name, "name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()), "itype"_a = get_type_string(out.dtype()),

View File

@@ -1,6 +1,6 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <fmt/format.h> #include <format>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
@@ -218,7 +218,7 @@ get_template_definition(std::string name, std::string func, Args... args) {
}; };
(add_arg(args), ...); (add_arg(args), ...);
s << ">"; s << ">";
return fmt::format( return std::format(
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n", "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
name, name,
s.str()); s.str());

View File

@@ -66,7 +66,7 @@ array compute_dynamic_offset(
auto dtype = indices.dtype(); auto dtype = indices.dtype();
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
auto lib = d.get_library(lib_name, [dtype]() { auto lib = d.get_library(lib_name, [dtype]() {
return fmt::format( return std::format(
R"( R"(
[[kernel]] void compute_dynamic_offset_{0}( [[kernel]] void compute_dynamic_offset_{0}(
constant const {1}* indices [[buffer(0)]], constant const {1}* indices [[buffer(0)]],