mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Replace fmt::format with std::format
This commit is contained in:
		@@ -223,10 +223,6 @@ target_include_directories(
 | 
			
		||||
  mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
 | 
			
		||||
             $<INSTALL_INTERFACE:include>)
 | 
			
		||||
 | 
			
		||||
# FetchContent_Declare( fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git
 | 
			
		||||
# GIT_TAG 10.2.1 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(fmt)
 | 
			
		||||
# target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 | 
			
		||||
 | 
			
		||||
if(MLX_BUILD_PYTHON_BINDINGS)
 | 
			
		||||
  message(STATUS "Building Python bindings.")
 | 
			
		||||
  find_package(
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,12 @@
 | 
			
		||||
 | 
			
		||||
#include <dlfcn.h>
 | 
			
		||||
#include <filesystem>
 | 
			
		||||
#include <format>
 | 
			
		||||
#include <fstream>
 | 
			
		||||
#include <list>
 | 
			
		||||
#include <mutex>
 | 
			
		||||
#include <shared_mutex>
 | 
			
		||||
 | 
			
		||||
#include <fmt/format.h>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/common/compiled.h"
 | 
			
		||||
#include "mlx/backend/common/compiled_preamble.h"
 | 
			
		||||
#include "mlx/backend/common/jit_compiler.h"
 | 
			
		||||
@@ -111,7 +110,7 @@ void* compile(
 | 
			
		||||
      JitCompiler::exec(JitCompiler::build_command(
 | 
			
		||||
          output_dir, source_file_name, shared_lib_name));
 | 
			
		||||
    } catch (const std::exception& error) {
 | 
			
		||||
      throw std::runtime_error(fmt::format(
 | 
			
		||||
      throw std::runtime_error(std::format(
 | 
			
		||||
          "[Compile::eval_cpu] Failed to compile function {0}: {1}",
 | 
			
		||||
          kernel_name,
 | 
			
		||||
          error.what()));
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,7 @@
 | 
			
		||||
#include <sstream>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include <fmt/format.h>
 | 
			
		||||
#include <format>
 | 
			
		||||
 | 
			
		||||
namespace mlx::core {
 | 
			
		||||
 | 
			
		||||
@@ -33,7 +33,7 @@ struct VisualStudioInfo {
 | 
			
		||||
    arch = "x64";
 | 
			
		||||
#endif
 | 
			
		||||
    // Get path of Visual Studio.
 | 
			
		||||
    std::string vs_path = JitCompiler::exec(fmt::format(
 | 
			
		||||
    std::string vs_path = JitCompiler::exec(std::format(
 | 
			
		||||
        "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
 | 
			
		||||
        " -property installationPath",
 | 
			
		||||
        std::getenv("ProgramFiles(x86)")));
 | 
			
		||||
@@ -41,7 +41,7 @@ struct VisualStudioInfo {
 | 
			
		||||
      throw std::runtime_error("Can not find Visual Studio.");
 | 
			
		||||
    }
 | 
			
		||||
    // Read the envs from vcvarsall.
 | 
			
		||||
    std::string envs = JitCompiler::exec(fmt::format(
 | 
			
		||||
    std::string envs = JitCompiler::exec(std::format(
 | 
			
		||||
        "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
 | 
			
		||||
        vs_path,
 | 
			
		||||
        arch));
 | 
			
		||||
@@ -55,7 +55,7 @@ struct VisualStudioInfo {
 | 
			
		||||
      if (name == "LIB") {
 | 
			
		||||
        libpaths = str_split(value, ';');
 | 
			
		||||
      } else if (name == "VCToolsInstallDir") {
 | 
			
		||||
        cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
 | 
			
		||||
        cl_exe = std::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@@ -81,9 +81,9 @@ std::string JitCompiler::build_command(
 | 
			
		||||
  const VisualStudioInfo& info = GetVisualStudioInfo();
 | 
			
		||||
  std::string libpaths;
 | 
			
		||||
  for (const std::string& lib : info.libpaths) {
 | 
			
		||||
    libpaths += fmt::format(" /libpath:\"{0}\"", lib);
 | 
			
		||||
    libpaths += std::format(" /libpath:\"{0}\"", lib);
 | 
			
		||||
  }
 | 
			
		||||
  return fmt::format(
 | 
			
		||||
  return std::format(
 | 
			
		||||
      "\""
 | 
			
		||||
      "cd /D \"{0}\" && "
 | 
			
		||||
      "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
 | 
			
		||||
@@ -95,7 +95,7 @@ std::string JitCompiler::build_command(
 | 
			
		||||
      shared_lib_name,
 | 
			
		||||
      libpaths);
 | 
			
		||||
#else
 | 
			
		||||
  return fmt::format(
 | 
			
		||||
  return std::format(
 | 
			
		||||
      "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1",
 | 
			
		||||
      (dir / source_file_name).string(),
 | 
			
		||||
      (dir / shared_lib_name).string());
 | 
			
		||||
@@ -139,7 +139,7 @@ std::string JitCompiler::exec(const std::string& cmd) {
 | 
			
		||||
  int code = WEXITSTATUS(status);
 | 
			
		||||
#endif
 | 
			
		||||
  if (code != 0) {
 | 
			
		||||
    throw std::runtime_error(fmt::format(
 | 
			
		||||
    throw std::runtime_error(std::format(
 | 
			
		||||
        "Failed to execute command with return code {0}: \"{1}\", "
 | 
			
		||||
        "the output is: {2}",
 | 
			
		||||
        code,
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
#include <fmt/format.h>
 | 
			
		||||
#include <format>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/common/compiled.h"
 | 
			
		||||
@@ -11,8 +11,6 @@
 | 
			
		||||
#include "mlx/primitives.h"
 | 
			
		||||
#include "mlx/utils.h"
 | 
			
		||||
 | 
			
		||||
using namespace fmt::literals;
 | 
			
		||||
 | 
			
		||||
namespace mlx::core {
 | 
			
		||||
 | 
			
		||||
inline void build_kernel(
 | 
			
		||||
@@ -41,7 +39,7 @@ inline void build_kernel(
 | 
			
		||||
  int cnt = 0;
 | 
			
		||||
 | 
			
		||||
  // Start the kernel
 | 
			
		||||
  os += fmt::format(
 | 
			
		||||
  os += std::format(
 | 
			
		||||
      "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
 | 
			
		||||
 | 
			
		||||
  // Add the input arguments
 | 
			
		||||
@@ -57,7 +55,7 @@ inline void build_kernel(
 | 
			
		||||
    if (!is_scalar(x) && !contiguous) {
 | 
			
		||||
      add_indices = true;
 | 
			
		||||
    }
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "    device const {0}* {1} [[buffer({2})]],\n",
 | 
			
		||||
        get_type_string(x.dtype()),
 | 
			
		||||
        xname,
 | 
			
		||||
@@ -65,13 +63,13 @@ inline void build_kernel(
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (add_indices) {
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "    constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Add the output arguments
 | 
			
		||||
  for (auto& x : outputs) {
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "    device {0}* {1} [[buffer({2})]],\n",
 | 
			
		||||
        get_type_string(x.dtype()),
 | 
			
		||||
        namer.get_name(x),
 | 
			
		||||
@@ -79,13 +77,13 @@ inline void build_kernel(
 | 
			
		||||
  }
 | 
			
		||||
  // Add output strides and shape to extract the indices.
 | 
			
		||||
  if (!contiguous) {
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "    constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "    constant const int* output_shape [[buffer({0})]],\n", cnt++);
 | 
			
		||||
  }
 | 
			
		||||
  if (dynamic_dims) {
 | 
			
		||||
    os += fmt::format("    constant const int& ndim [[buffer({0})]],\n", cnt++);
 | 
			
		||||
    os += std::format("    constant const int& ndim [[buffer({0})]],\n", cnt++);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // The thread index in the whole grid
 | 
			
		||||
@@ -98,15 +96,15 @@ inline void build_kernel(
 | 
			
		||||
    // a third grid dimension
 | 
			
		||||
    os += "  int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
 | 
			
		||||
  } else if (work_per_thread > 1) {
 | 
			
		||||
    os += fmt::format("  constexpr int N_ = {0};\n", work_per_thread);
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format("  constexpr int N_ = {0};\n", work_per_thread);
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "  int xshape = output_shape[{0}];\n",
 | 
			
		||||
        dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "  {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
 | 
			
		||||
        idx_type);
 | 
			
		||||
  } else {
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "  {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
 | 
			
		||||
        idx_type);
 | 
			
		||||
  }
 | 
			
		||||
@@ -121,16 +119,16 @@ inline void build_kernel(
 | 
			
		||||
      auto type_str = get_type_string(x.dtype());
 | 
			
		||||
      std::ostringstream ss;
 | 
			
		||||
      print_constant(ss, x);
 | 
			
		||||
      os += fmt::format(
 | 
			
		||||
      os += std::format(
 | 
			
		||||
          "  auto tmp_{0} = static_cast<{1}>({2});\n",
 | 
			
		||||
          xname,
 | 
			
		||||
          get_type_string(x.dtype()),
 | 
			
		||||
          ss.str());
 | 
			
		||||
    } else if (is_scalar(x)) {
 | 
			
		||||
      os += fmt::format(
 | 
			
		||||
      os += std::format(
 | 
			
		||||
          "  {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
 | 
			
		||||
    } else if (contiguous) {
 | 
			
		||||
      os += fmt::format(
 | 
			
		||||
      os += std::format(
 | 
			
		||||
          "  {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
 | 
			
		||||
    } else {
 | 
			
		||||
      nc_inputs.push_back(x);
 | 
			
		||||
@@ -140,30 +138,30 @@ inline void build_kernel(
 | 
			
		||||
  // Initialize the indices for non-contiguous inputs
 | 
			
		||||
  for (int i = 0; i < nc_inputs.size(); ++i) {
 | 
			
		||||
    auto& xname = namer.get_name(nc_inputs[i]);
 | 
			
		||||
    os += fmt::format("  {0} index_{1} = ", idx_type, xname);
 | 
			
		||||
    os += std::format("  {0} index_{1} = ", idx_type, xname);
 | 
			
		||||
    if (ndim == 1) {
 | 
			
		||||
      int offset = i * ndim;
 | 
			
		||||
      os +=
 | 
			
		||||
          fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
 | 
			
		||||
          std::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
 | 
			
		||||
    } else if (ndim == 2) {
 | 
			
		||||
      int offset = i * ndim;
 | 
			
		||||
      os += fmt::format(
 | 
			
		||||
      os += std::format(
 | 
			
		||||
          "elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
 | 
			
		||||
          idx_type,
 | 
			
		||||
          offset);
 | 
			
		||||
    } else if (ndim == 3) {
 | 
			
		||||
      int offset = i * ndim;
 | 
			
		||||
      os += fmt::format(
 | 
			
		||||
      os += std::format(
 | 
			
		||||
          "elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
 | 
			
		||||
    } else if (!dynamic_dims) {
 | 
			
		||||
      int offset = (i + 1) * ndim;
 | 
			
		||||
      os += fmt::format(
 | 
			
		||||
      os += std::format(
 | 
			
		||||
          "N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
 | 
			
		||||
          idx_type,
 | 
			
		||||
          offset - 1,
 | 
			
		||||
          offset - 2);
 | 
			
		||||
    } else {
 | 
			
		||||
      os += fmt::format(
 | 
			
		||||
      os += std::format(
 | 
			
		||||
          "N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
 | 
			
		||||
          idx_type,
 | 
			
		||||
          i);
 | 
			
		||||
@@ -175,18 +173,18 @@ inline void build_kernel(
 | 
			
		||||
    if (dynamic_dims) {
 | 
			
		||||
      os += "  for (int d = ndim - 3; d >= 0; --d) {\n";
 | 
			
		||||
    } else {
 | 
			
		||||
      os += fmt::format("  for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
 | 
			
		||||
      os += std::format("  for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
 | 
			
		||||
    }
 | 
			
		||||
    os += "    uint l = zpos % output_shape[d];\n";
 | 
			
		||||
    for (int i = 0; i < nc_inputs.size(); ++i) {
 | 
			
		||||
      auto& xname = namer.get_name(nc_inputs[i]);
 | 
			
		||||
      os += fmt::format("    index_{0} += ", xname);
 | 
			
		||||
      os += std::format("    index_{0} += ", xname);
 | 
			
		||||
      if (dynamic_dims) {
 | 
			
		||||
        os +=
 | 
			
		||||
            fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
 | 
			
		||||
            std::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
 | 
			
		||||
      } else {
 | 
			
		||||
        os +=
 | 
			
		||||
            fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
 | 
			
		||||
            std::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    os += "    zpos /= output_shape[d];\n  }\n";
 | 
			
		||||
@@ -202,16 +200,16 @@ inline void build_kernel(
 | 
			
		||||
  for (int i = 0; i < nc_inputs.size(); ++i) {
 | 
			
		||||
    auto& x = nc_inputs[i];
 | 
			
		||||
    auto& xname = namer.get_name(x);
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "  {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Actually write the computation
 | 
			
		||||
  for (auto& x : tape) {
 | 
			
		||||
    os += fmt::format(
 | 
			
		||||
    os += std::format(
 | 
			
		||||
        "  {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
 | 
			
		||||
    if (is_static_cast(x.primitive())) {
 | 
			
		||||
      os += fmt::format(
 | 
			
		||||
      os += std::format(
 | 
			
		||||
          "static_cast<{0}>(tmp_{1});\n",
 | 
			
		||||
          get_type_string(x.dtype()),
 | 
			
		||||
          namer.get_name(x.inputs()[0]));
 | 
			
		||||
@@ -221,15 +219,15 @@ inline void build_kernel(
 | 
			
		||||
      os += ss.str();
 | 
			
		||||
      os += "()(";
 | 
			
		||||
      for (int i = 0; i < x.inputs().size() - 1; i++) {
 | 
			
		||||
        os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
 | 
			
		||||
        os += std::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
 | 
			
		||||
      }
 | 
			
		||||
      os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
 | 
			
		||||
      os += std::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Write the outputs from tmps
 | 
			
		||||
  for (auto& x : outputs) {
 | 
			
		||||
    os += fmt::format("  {0}[index] = tmp_{0};\n", namer.get_name(x));
 | 
			
		||||
    os += std::format("  {0}[index] = tmp_{0};\n", namer.get_name(x));
 | 
			
		||||
  }
 | 
			
		||||
  // Increment indices and close per thread loop
 | 
			
		||||
  if (work_per_thread > 1) {
 | 
			
		||||
@@ -237,10 +235,10 @@ inline void build_kernel(
 | 
			
		||||
      auto& x = nc_inputs[i];
 | 
			
		||||
      auto& xname = namer.get_name(x);
 | 
			
		||||
      if (!dynamic_dims) {
 | 
			
		||||
        os += fmt::format(
 | 
			
		||||
        os += std::format(
 | 
			
		||||
            "  index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
 | 
			
		||||
      } else {
 | 
			
		||||
        os += fmt::format(
 | 
			
		||||
        os += std::format(
 | 
			
		||||
            "  index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
#include <fmt/format.h>
 | 
			
		||||
#include <format>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/common/compiled.h"
 | 
			
		||||
#include "mlx/backend/metal/copy.h"
 | 
			
		||||
@@ -20,9 +20,9 @@ std::pair<std::string, std::string> make_index_args(
 | 
			
		||||
  std::ostringstream idx_args;
 | 
			
		||||
  std::ostringstream idx_arr;
 | 
			
		||||
  for (int i = 0; i < nidx; ++i) {
 | 
			
		||||
    idx_args << fmt::format(
 | 
			
		||||
    idx_args << std::format(
 | 
			
		||||
        "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
 | 
			
		||||
    idx_arr << fmt::format("idx{0}", i);
 | 
			
		||||
    idx_arr << std::format("idx{0}", i);
 | 
			
		||||
    if (i < nidx - 1) {
 | 
			
		||||
      idx_args << "\n";
 | 
			
		||||
      idx_arr << ",";
 | 
			
		||||
@@ -59,7 +59,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  bool large = large_index || large_src || large_out;
 | 
			
		||||
 | 
			
		||||
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
 | 
			
		||||
  std::string kernel_name = fmt::format(
 | 
			
		||||
  std::string kernel_name = std::format(
 | 
			
		||||
      "gather{0}{1}_{2}_{3}_{4}",
 | 
			
		||||
      type_to_name(out),
 | 
			
		||||
      idx_type_name,
 | 
			
		||||
@@ -77,7 +77,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
 | 
			
		||||
 | 
			
		||||
    // Index dimension specializations
 | 
			
		||||
    kernel_source += fmt::format(
 | 
			
		||||
    kernel_source += std::format(
 | 
			
		||||
        gather_kernels,
 | 
			
		||||
        type_to_name(out) + idx_type_name,
 | 
			
		||||
        out_type_str,
 | 
			
		||||
@@ -238,7 +238,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
 | 
			
		||||
  bool large_upd = upd.size() > INT32_MAX;
 | 
			
		||||
  bool large = large_out || large_idx || large_upd;
 | 
			
		||||
  std::string kernel_name = fmt::format(
 | 
			
		||||
  std::string kernel_name = std::format(
 | 
			
		||||
      "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
 | 
			
		||||
      type_to_name(out),
 | 
			
		||||
      idx_type_name,
 | 
			
		||||
@@ -275,11 +275,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
    if (reduce_type_ != Scatter::None) {
 | 
			
		||||
      op_type = fmt::format(fmt::runtime(op_type), out_type_str);
 | 
			
		||||
      op_type = std::vformat(op_type, std::make_format_args(out_type_str));
 | 
			
		||||
    }
 | 
			
		||||
    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
 | 
			
		||||
 | 
			
		||||
    kernel_source += fmt::format(
 | 
			
		||||
    kernel_source += std::format(
 | 
			
		||||
        scatter_kernels,
 | 
			
		||||
        type_to_name(out) + idx_type_name + "_" + op_name,
 | 
			
		||||
        out_type_str,
 | 
			
		||||
 
 | 
			
		||||
@@ -9,8 +9,6 @@
 | 
			
		||||
#include "mlx/backend/metal/kernels.h"
 | 
			
		||||
#include "mlx/backend/metal/utils.h"
 | 
			
		||||
 | 
			
		||||
using namespace fmt::literals;
 | 
			
		||||
 | 
			
		||||
namespace mlx::core {
 | 
			
		||||
 | 
			
		||||
std::string op_name(const array& arr) {
 | 
			
		||||
@@ -26,7 +24,7 @@ MTL::ComputePipelineState* get_arange_kernel(
 | 
			
		||||
  auto lib = d.get_library(kernel_name, [&]() {
 | 
			
		||||
    std::ostringstream kernel_source;
 | 
			
		||||
    kernel_source << metal::utils() << metal::arange()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                  << std::format(
 | 
			
		||||
                         arange_kernels,
 | 
			
		||||
                         kernel_name,
 | 
			
		||||
                         get_type_string(out.dtype()));
 | 
			
		||||
@@ -259,7 +257,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
 | 
			
		||||
  auto lib = d.get_library(lib_name, [&] {
 | 
			
		||||
    std::ostringstream kernel_source;
 | 
			
		||||
    kernel_source << metal::utils() << metal::softmax()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                  << std::format(
 | 
			
		||||
                         softmax_kernels,
 | 
			
		||||
                         lib_name,
 | 
			
		||||
                         get_type_string(out.dtype()),
 | 
			
		||||
@@ -445,7 +443,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
 | 
			
		||||
    std::ostringstream kernel_source;
 | 
			
		||||
    kernel_source << metal::utils() << metal::gemm()
 | 
			
		||||
                  << metal::steel_gemm_fused()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                  << std::format(
 | 
			
		||||
                         steel_gemm_fused_kernels,
 | 
			
		||||
                         "name"_a = lib_name,
 | 
			
		||||
                         "itype"_a = get_type_string(out.dtype()),
 | 
			
		||||
@@ -480,7 +478,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
 | 
			
		||||
    std::ostringstream kernel_source;
 | 
			
		||||
    kernel_source << metal::utils() << metal::gemm()
 | 
			
		||||
                  << metal::steel_gemm_splitk()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                  << std::format(
 | 
			
		||||
                         steel_gemm_splitk_kernels,
 | 
			
		||||
                         "name"_a = lib_name,
 | 
			
		||||
                         "itype"_a = get_type_string(in.dtype()),
 | 
			
		||||
@@ -510,13 +508,13 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
 | 
			
		||||
    std::ostringstream kernel_source;
 | 
			
		||||
    kernel_source << metal::utils() << metal::gemm()
 | 
			
		||||
                  << metal::steel_gemm_splitk()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                         fmt::runtime(
 | 
			
		||||
                             axbpy ? steel_gemm_splitk_accum_axbpy_kernels
 | 
			
		||||
                                   : steel_gemm_splitk_accum_kernels),
 | 
			
		||||
                         "name"_a = lib_name,
 | 
			
		||||
                         "atype"_a = get_type_string(in.dtype()),
 | 
			
		||||
                         "otype"_a = get_type_string(out.dtype()));
 | 
			
		||||
                  << std::vformat(
 | 
			
		||||
                         axbpy ? steel_gemm_splitk_accum_axbpy_kernels
 | 
			
		||||
                               : steel_gemm_splitk_accum_kernels,
 | 
			
		||||
                         std::make_format_args(
 | 
			
		||||
                             "name"_a = lib_name,
 | 
			
		||||
                             "atype"_a = get_type_string(in.dtype()),
 | 
			
		||||
                             "otype"_a = get_type_string(out.dtype())));
 | 
			
		||||
    return kernel_source.str();
 | 
			
		||||
  });
 | 
			
		||||
  return d.get_kernel(kernel_name, lib);
 | 
			
		||||
@@ -547,7 +545,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
 | 
			
		||||
        mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
 | 
			
		||||
    kernel_source << metal::utils() << metal::gemm()
 | 
			
		||||
                  << metal::steel_gemm_masked()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                  << std::format(
 | 
			
		||||
                         steel_gemm_masked_kernels,
 | 
			
		||||
                         "name"_a = lib_name,
 | 
			
		||||
                         "itype"_a = get_type_string(out.dtype()),
 | 
			
		||||
@@ -590,7 +588,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
 | 
			
		||||
    auto op_mask_type =
 | 
			
		||||
        mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
 | 
			
		||||
    kernel_source << metal::utils() << metal::gemv_masked()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                  << std::format(
 | 
			
		||||
                         gemv_masked_kernel,
 | 
			
		||||
                         "name"_a = lib_name,
 | 
			
		||||
                         "itype"_a = get_type_string(out.dtype()),
 | 
			
		||||
@@ -624,7 +622,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
 | 
			
		||||
  auto lib = d.get_library(lib_name, [&]() {
 | 
			
		||||
    std::ostringstream kernel_source;
 | 
			
		||||
    kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                  << std::format(
 | 
			
		||||
                         steel_conv_kernels,
 | 
			
		||||
                         "name"_a = lib_name,
 | 
			
		||||
                         "itype"_a = get_type_string(out.dtype()),
 | 
			
		||||
@@ -654,7 +652,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
 | 
			
		||||
    std::ostringstream kernel_source;
 | 
			
		||||
    kernel_source << metal::utils() << metal::conv()
 | 
			
		||||
                  << metal::steel_conv_general()
 | 
			
		||||
                  << fmt::format(
 | 
			
		||||
                  << std::format(
 | 
			
		||||
                         steel_conv_general_kernels,
 | 
			
		||||
                         "name"_a = lib_name,
 | 
			
		||||
                         "itype"_a = get_type_string(out.dtype()),
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
// Copyright © 2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <fmt/format.h>
 | 
			
		||||
#include <format>
 | 
			
		||||
 | 
			
		||||
#include "mlx/array.h"
 | 
			
		||||
#include "mlx/backend/metal/device.h"
 | 
			
		||||
@@ -218,7 +218,7 @@ get_template_definition(std::string name, std::string func, Args... args) {
 | 
			
		||||
  };
 | 
			
		||||
  (add_arg(args), ...);
 | 
			
		||||
  s << ">";
 | 
			
		||||
  return fmt::format(
 | 
			
		||||
  return std::format(
 | 
			
		||||
      "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
 | 
			
		||||
      name,
 | 
			
		||||
      s.str());
 | 
			
		||||
 
 | 
			
		||||
@@ -66,7 +66,7 @@ array compute_dynamic_offset(
 | 
			
		||||
  auto dtype = indices.dtype();
 | 
			
		||||
  std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
 | 
			
		||||
  auto lib = d.get_library(lib_name, [dtype]() {
 | 
			
		||||
    return fmt::format(
 | 
			
		||||
    return std::format(
 | 
			
		||||
        R"(
 | 
			
		||||
        [[kernel]] void compute_dynamic_offset_{0}(
 | 
			
		||||
            constant const {1}* indices [[buffer(0)]],
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user