mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4515866024 | ||
|
|
6fe2b82926 | ||
|
|
c75b5e9d19 | ||
|
|
6f12eda549 | ||
|
|
a541fe9312 | ||
|
|
2bdd20f257 | ||
|
|
aa7b9688ce | ||
|
|
0a41393dba | ||
|
|
e300a01f4a |
@@ -24,7 +24,7 @@ jobs:
|
|||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
macos:
|
macos:
|
||||||
xcode: "15.2.0"
|
xcode: "16.0.0"
|
||||||
resource_class: macos.m1.medium.gen1
|
resource_class: macos.m1.medium.gen1
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
@@ -70,8 +70,8 @@ jobs:
|
|||||||
git push -f origin gh-pages
|
git push -f origin gh-pages
|
||||||
|
|
||||||
linux_build_and_test:
|
linux_build_and_test:
|
||||||
docker:
|
machine:
|
||||||
- image: cimg/python:3.9
|
image: ubuntu-2404:2024.11.1
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
@@ -84,30 +84,33 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: |
|
command: |
|
||||||
pip install --upgrade cmake
|
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||||
pip install nanobind==2.4.0
|
sudo apt-get update -y
|
||||||
pip install numpy
|
sudo apt-get install -y python3.9 python3.9-distutils python3.9-dev
|
||||||
|
python3.9 -m pip install --upgrade cmake
|
||||||
|
python3.9 -m pip install nanobind==2.4.0
|
||||||
|
python3.9 -m pip install numpy
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libopenblas-dev liblapacke-dev openmpi-bin libopenmpi-dev
|
||||||
- run:
|
- run:
|
||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF -DPython_EXECUTABLE=/usr/bin/python3.9" \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
python3 setup.py build_ext --inplace
|
python3.9 setup.py build_ext --inplace
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF -DPython_EXECUTABLE=/usr/bin/python3.9" \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
python3 setup.py develop
|
python3.9 setup.py develop
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
command: |
|
command: |
|
||||||
echo "stubs"
|
echo "stubs"
|
||||||
pip install typing_extensions
|
python3.9 -m pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python3.9 setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
python3 -m unittest discover python/tests -v
|
python3.9 -m unittest discover python/tests -v
|
||||||
- run:
|
- run:
|
||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
@@ -122,7 +125,10 @@ jobs:
|
|||||||
parameters:
|
parameters:
|
||||||
xcode_version:
|
xcode_version:
|
||||||
type: string
|
type: string
|
||||||
default: "15.2.0"
|
default: "16.0.0"
|
||||||
|
deployment_target:
|
||||||
|
type: string
|
||||||
|
default: ""
|
||||||
macos:
|
macos:
|
||||||
xcode: << parameters.xcode_version >>
|
xcode: << parameters.xcode_version >>
|
||||||
resource_class: macos.m1.medium.gen1
|
resource_class: macos.m1.medium.gen1
|
||||||
@@ -146,7 +152,9 @@ jobs:
|
|||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
|
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||||
|
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
|
||||||
|
pip install -e . -v
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
command: |
|
command: |
|
||||||
@@ -173,7 +181,11 @@ jobs:
|
|||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
mkdir -p build
|
||||||
|
cd build/
|
||||||
|
cmake .. \
|
||||||
|
-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>
|
||||||
|
make -j `sysctl -n hw.ncpu`
|
||||||
- run:
|
- run:
|
||||||
name: Run CPP tests
|
name: Run CPP tests
|
||||||
command: |
|
command: |
|
||||||
@@ -188,14 +200,15 @@ jobs:
|
|||||||
-DMLX_BUILD_CPU=OFF \
|
-DMLX_BUILD_CPU=OFF \
|
||||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||||
-DMLX_BUILD_GGUF=OFF \
|
-DMLX_BUILD_GGUF=OFF \
|
||||||
-DMLX_METAL_JIT=ON
|
-DMLX_METAL_JIT=ON \
|
||||||
|
-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>
|
||||||
make -j `sysctl -n hw.ncpu`
|
make -j `sysctl -n hw.ncpu`
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests with JIT
|
name: Run Python tests with JIT
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
CMAKE_ARGS="-DMLX_METAL_JIT=ON -DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>" \
|
||||||
pip install -e . -v
|
pip install -e . -v
|
||||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||||
METAL_DEBUG_ERROR_MODE=0 \
|
METAL_DEBUG_ERROR_MODE=0 \
|
||||||
@@ -208,7 +221,10 @@ jobs:
|
|||||||
default: "3.9"
|
default: "3.9"
|
||||||
xcode_version:
|
xcode_version:
|
||||||
type: string
|
type: string
|
||||||
default: "15.2.0"
|
default: "16.0.0"
|
||||||
|
deployment_target:
|
||||||
|
type: string
|
||||||
|
default: ""
|
||||||
build_env:
|
build_env:
|
||||||
type: string
|
type: string
|
||||||
default: ""
|
default: ""
|
||||||
@@ -237,6 +253,7 @@ jobs:
|
|||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
DEV_RELEASE=1 \
|
DEV_RELEASE=1 \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||||
|
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
|
||||||
pip install . -v
|
pip install . -v
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
@@ -250,6 +267,7 @@ jobs:
|
|||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
<< parameters.build_env >> \
|
<< parameters.build_env >> \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||||
|
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
|
||||||
python -m build -w
|
python -m build -w
|
||||||
- when:
|
- when:
|
||||||
condition: << parameters.build_env >>
|
condition: << parameters.build_env >>
|
||||||
@@ -330,9 +348,10 @@ workflows:
|
|||||||
- mac_build_and_test:
|
- mac_build_and_test:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
xcode_version: ["16.0.0"]
|
||||||
|
deployment_target: ["", "13.5"]
|
||||||
- linux_build_and_test
|
- linux_build_and_test
|
||||||
- build_documentation
|
- build_documentation
|
||||||
|
|
||||||
build_pypi_release:
|
build_pypi_release:
|
||||||
when:
|
when:
|
||||||
@@ -350,7 +369,8 @@ workflows:
|
|||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
xcode_version: ["15.0.0", "15.2.0"]
|
xcode_version: ["16.0.0"]
|
||||||
|
deployment_target: ["", "13.5"]
|
||||||
build_env: ["PYPI_RELEASE=1"]
|
build_env: ["PYPI_RELEASE=1"]
|
||||||
- build_documentation:
|
- build_documentation:
|
||||||
filters:
|
filters:
|
||||||
@@ -374,7 +394,8 @@ workflows:
|
|||||||
requires: [ hold ]
|
requires: [ hold ]
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
xcode_version: ["16.0.0"]
|
||||||
|
deployment_target: ["", "13.5"]
|
||||||
- linux_build_and_test:
|
- linux_build_and_test:
|
||||||
requires: [ hold ]
|
requires: [ hold ]
|
||||||
nightly_build:
|
nightly_build:
|
||||||
@@ -387,7 +408,8 @@ workflows:
|
|||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
xcode_version: ["15.0.0", "15.2.0"]
|
xcode_version: ["16.0.0"]
|
||||||
|
deployment_target: ["", "13.5"]
|
||||||
weekly_build:
|
weekly_build:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
@@ -398,7 +420,8 @@ workflows:
|
|||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
xcode_version: ["16.0.0"]
|
||||||
|
deployment_target: ["", "13.5"]
|
||||||
build_env: ["DEV_RELEASE=1"]
|
build_env: ["DEV_RELEASE=1"]
|
||||||
linux_test_release:
|
linux_test_release:
|
||||||
when:
|
when:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ project(mlx LANGUAGES C CXX)
|
|||||||
|
|
||||||
# ----------------------------- Setup -----------------------------
|
# ----------------------------- Setup -----------------------------
|
||||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 20)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||||
@@ -223,14 +223,6 @@ target_include_directories(
|
|||||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||||
$<INSTALL_INTERFACE:include>)
|
$<INSTALL_INTERFACE:include>)
|
||||||
|
|
||||||
FetchContent_Declare(
|
|
||||||
fmt
|
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
|
||||||
GIT_TAG 10.2.1
|
|
||||||
EXCLUDE_FROM_ALL)
|
|
||||||
FetchContent_MakeAvailable(fmt)
|
|
||||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
|
||||||
|
|
||||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
message(STATUS "Building Python bindings.")
|
message(STATUS "Building Python bindings.")
|
||||||
find_package(
|
find_package(
|
||||||
|
|||||||
@@ -2,13 +2,12 @@
|
|||||||
|
|
||||||
#include <dlfcn.h>
|
#include <dlfcn.h>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
|
#include <format>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <shared_mutex>
|
#include <shared_mutex>
|
||||||
|
|
||||||
#include <fmt/format.h>
|
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/common/compiled_preamble.h"
|
#include "mlx/backend/common/compiled_preamble.h"
|
||||||
#include "mlx/backend/common/jit_compiler.h"
|
#include "mlx/backend/common/jit_compiler.h"
|
||||||
@@ -111,7 +110,7 @@ void* compile(
|
|||||||
JitCompiler::exec(JitCompiler::build_command(
|
JitCompiler::exec(JitCompiler::build_command(
|
||||||
output_dir, source_file_name, shared_lib_name));
|
output_dir, source_file_name, shared_lib_name));
|
||||||
} catch (const std::exception& error) {
|
} catch (const std::exception& error) {
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(std::format(
|
||||||
"[Compile::eval_cpu] Failed to compile function {0}: {1}",
|
"[Compile::eval_cpu] Failed to compile function {0}: {1}",
|
||||||
kernel_name,
|
kernel_name,
|
||||||
error.what()));
|
error.what()));
|
||||||
|
|||||||
@@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/jit_compiler.h"
|
#include "mlx/backend/common/jit_compiler.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <fmt/format.h>
|
#include <format>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ struct VisualStudioInfo {
|
|||||||
arch = "x64";
|
arch = "x64";
|
||||||
#endif
|
#endif
|
||||||
// Get path of Visual Studio.
|
// Get path of Visual Studio.
|
||||||
std::string vs_path = JitCompiler::exec(fmt::format(
|
std::string vs_path = JitCompiler::exec(std::format(
|
||||||
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
|
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
|
||||||
" -property installationPath",
|
" -property installationPath",
|
||||||
std::getenv("ProgramFiles(x86)")));
|
std::getenv("ProgramFiles(x86)")));
|
||||||
@@ -41,7 +42,7 @@ struct VisualStudioInfo {
|
|||||||
throw std::runtime_error("Can not find Visual Studio.");
|
throw std::runtime_error("Can not find Visual Studio.");
|
||||||
}
|
}
|
||||||
// Read the envs from vcvarsall.
|
// Read the envs from vcvarsall.
|
||||||
std::string envs = JitCompiler::exec(fmt::format(
|
std::string envs = JitCompiler::exec(std::format(
|
||||||
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
|
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
|
||||||
vs_path,
|
vs_path,
|
||||||
arch));
|
arch));
|
||||||
@@ -55,7 +56,7 @@ struct VisualStudioInfo {
|
|||||||
if (name == "LIB") {
|
if (name == "LIB") {
|
||||||
libpaths = str_split(value, ';');
|
libpaths = str_split(value, ';');
|
||||||
} else if (name == "VCToolsInstallDir") {
|
} else if (name == "VCToolsInstallDir") {
|
||||||
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
|
cl_exe = std::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,9 +82,9 @@ std::string JitCompiler::build_command(
|
|||||||
const VisualStudioInfo& info = GetVisualStudioInfo();
|
const VisualStudioInfo& info = GetVisualStudioInfo();
|
||||||
std::string libpaths;
|
std::string libpaths;
|
||||||
for (const std::string& lib : info.libpaths) {
|
for (const std::string& lib : info.libpaths) {
|
||||||
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
|
libpaths += std::format(" /libpath:\"{0}\"", lib);
|
||||||
}
|
}
|
||||||
return fmt::format(
|
return std::format(
|
||||||
"\""
|
"\""
|
||||||
"cd /D \"{0}\" && "
|
"cd /D \"{0}\" && "
|
||||||
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
|
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
|
||||||
@@ -95,7 +96,7 @@ std::string JitCompiler::build_command(
|
|||||||
shared_lib_name,
|
shared_lib_name,
|
||||||
libpaths);
|
libpaths);
|
||||||
#else
|
#else
|
||||||
return fmt::format(
|
return std::format(
|
||||||
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1",
|
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1",
|
||||||
(dir / source_file_name).string(),
|
(dir / source_file_name).string(),
|
||||||
(dir / shared_lib_name).string());
|
(dir / shared_lib_name).string());
|
||||||
@@ -139,7 +140,7 @@ std::string JitCompiler::exec(const std::string& cmd) {
|
|||||||
int code = WEXITSTATUS(status);
|
int code = WEXITSTATUS(status);
|
||||||
#endif
|
#endif
|
||||||
if (code != 0) {
|
if (code != 0) {
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(std::format(
|
||||||
"Failed to execute command with return code {0}: \"{1}\", "
|
"Failed to execute command with return code {0}: \"{1}\", "
|
||||||
"the output is: {2}",
|
"the output is: {2}",
|
||||||
code,
|
code,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <fmt/format.h>
|
#include <format>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
@@ -11,8 +11,6 @@
|
|||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
using namespace fmt::literals;
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
inline void build_kernel(
|
inline void build_kernel(
|
||||||
@@ -41,7 +39,7 @@ inline void build_kernel(
|
|||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
|
||||||
// Start the kernel
|
// Start the kernel
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
||||||
|
|
||||||
// Add the input arguments
|
// Add the input arguments
|
||||||
@@ -57,7 +55,7 @@ inline void build_kernel(
|
|||||||
if (!is_scalar(x) && !contiguous) {
|
if (!is_scalar(x) && !contiguous) {
|
||||||
add_indices = true;
|
add_indices = true;
|
||||||
}
|
}
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" device const {0}* {1} [[buffer({2})]],\n",
|
" device const {0}* {1} [[buffer({2})]],\n",
|
||||||
get_type_string(x.dtype()),
|
get_type_string(x.dtype()),
|
||||||
xname,
|
xname,
|
||||||
@@ -65,13 +63,13 @@ inline void build_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (add_indices) {
|
if (add_indices) {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the output arguments
|
// Add the output arguments
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" device {0}* {1} [[buffer({2})]],\n",
|
" device {0}* {1} [[buffer({2})]],\n",
|
||||||
get_type_string(x.dtype()),
|
get_type_string(x.dtype()),
|
||||||
namer.get_name(x),
|
namer.get_name(x),
|
||||||
@@ -79,13 +77,13 @@ inline void build_kernel(
|
|||||||
}
|
}
|
||||||
// Add output strides and shape to extract the indices.
|
// Add output strides and shape to extract the indices.
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||||
}
|
}
|
||||||
if (dynamic_dims) {
|
if (dynamic_dims) {
|
||||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
os += std::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The thread index in the whole grid
|
// The thread index in the whole grid
|
||||||
@@ -98,15 +96,15 @@ inline void build_kernel(
|
|||||||
// a third grid dimension
|
// a third grid dimension
|
||||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||||
} else if (work_per_thread > 1) {
|
} else if (work_per_thread > 1) {
|
||||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
os += std::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" int xshape = output_shape[{0}];\n",
|
" int xshape = output_shape[{0}];\n",
|
||||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
|
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||||
idx_type);
|
idx_type);
|
||||||
} else {
|
} else {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||||
idx_type);
|
idx_type);
|
||||||
}
|
}
|
||||||
@@ -121,16 +119,16 @@ inline void build_kernel(
|
|||||||
auto type_str = get_type_string(x.dtype());
|
auto type_str = get_type_string(x.dtype());
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
print_constant(ss, x);
|
print_constant(ss, x);
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" auto tmp_{0} = static_cast<{1}>({2});\n",
|
" auto tmp_{0} = static_cast<{1}>({2});\n",
|
||||||
xname,
|
xname,
|
||||||
get_type_string(x.dtype()),
|
get_type_string(x.dtype()),
|
||||||
ss.str());
|
ss.str());
|
||||||
} else if (is_scalar(x)) {
|
} else if (is_scalar(x)) {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
|
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
|
||||||
} else if (contiguous) {
|
} else if (contiguous) {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
|
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
|
||||||
} else {
|
} else {
|
||||||
nc_inputs.push_back(x);
|
nc_inputs.push_back(x);
|
||||||
@@ -140,30 +138,30 @@ inline void build_kernel(
|
|||||||
// Initialize the indices for non-contiguous inputs
|
// Initialize the indices for non-contiguous inputs
|
||||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||||
auto& xname = namer.get_name(nc_inputs[i]);
|
auto& xname = namer.get_name(nc_inputs[i]);
|
||||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
os += std::format(" {0} index_{1} = ", idx_type, xname);
|
||||||
if (ndim == 1) {
|
if (ndim == 1) {
|
||||||
int offset = i * ndim;
|
int offset = i * ndim;
|
||||||
os +=
|
os +=
|
||||||
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
|
std::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||||
} else if (ndim == 2) {
|
} else if (ndim == 2) {
|
||||||
int offset = i * ndim;
|
int offset = i * ndim;
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||||
idx_type,
|
idx_type,
|
||||||
offset);
|
offset);
|
||||||
} else if (ndim == 3) {
|
} else if (ndim == 3) {
|
||||||
int offset = i * ndim;
|
int offset = i * ndim;
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
|
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
|
||||||
} else if (!dynamic_dims) {
|
} else if (!dynamic_dims) {
|
||||||
int offset = (i + 1) * ndim;
|
int offset = (i + 1) * ndim;
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
|
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
|
||||||
idx_type,
|
idx_type,
|
||||||
offset - 1,
|
offset - 1,
|
||||||
offset - 2);
|
offset - 2);
|
||||||
} else {
|
} else {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
|
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
|
||||||
idx_type,
|
idx_type,
|
||||||
i);
|
i);
|
||||||
@@ -175,18 +173,18 @@ inline void build_kernel(
|
|||||||
if (dynamic_dims) {
|
if (dynamic_dims) {
|
||||||
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
|
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||||
} else {
|
} else {
|
||||||
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
os += std::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
||||||
}
|
}
|
||||||
os += " uint l = zpos % output_shape[d];\n";
|
os += " uint l = zpos % output_shape[d];\n";
|
||||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||||
auto& xname = namer.get_name(nc_inputs[i]);
|
auto& xname = namer.get_name(nc_inputs[i]);
|
||||||
os += fmt::format(" index_{0} += ", xname);
|
os += std::format(" index_{0} += ", xname);
|
||||||
if (dynamic_dims) {
|
if (dynamic_dims) {
|
||||||
os +=
|
os +=
|
||||||
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
|
std::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
|
||||||
} else {
|
} else {
|
||||||
os +=
|
os +=
|
||||||
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
|
std::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
os += " zpos /= output_shape[d];\n }\n";
|
os += " zpos /= output_shape[d];\n }\n";
|
||||||
@@ -202,16 +200,16 @@ inline void build_kernel(
|
|||||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||||
auto& x = nc_inputs[i];
|
auto& x = nc_inputs[i];
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
|
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Actually write the computation
|
// Actually write the computation
|
||||||
for (auto& x : tape) {
|
for (auto& x : tape) {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
|
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
|
||||||
if (is_static_cast(x.primitive())) {
|
if (is_static_cast(x.primitive())) {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
"static_cast<{0}>(tmp_{1});\n",
|
"static_cast<{0}>(tmp_{1});\n",
|
||||||
get_type_string(x.dtype()),
|
get_type_string(x.dtype()),
|
||||||
namer.get_name(x.inputs()[0]));
|
namer.get_name(x.inputs()[0]));
|
||||||
@@ -221,15 +219,15 @@ inline void build_kernel(
|
|||||||
os += ss.str();
|
os += ss.str();
|
||||||
os += "()(";
|
os += "()(";
|
||||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
os += std::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||||
}
|
}
|
||||||
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
os += std::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write the outputs from tmps
|
// Write the outputs from tmps
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
os += std::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||||
}
|
}
|
||||||
// Increment indices and close per thread loop
|
// Increment indices and close per thread loop
|
||||||
if (work_per_thread > 1) {
|
if (work_per_thread > 1) {
|
||||||
@@ -237,10 +235,10 @@ inline void build_kernel(
|
|||||||
auto& x = nc_inputs[i];
|
auto& x = nc_inputs[i];
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
if (!dynamic_dims) {
|
if (!dynamic_dims) {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
|
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
|
||||||
} else {
|
} else {
|
||||||
os += fmt::format(
|
os += std::format(
|
||||||
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
|
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <fmt/format.h>
|
#include <format>
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/metal/copy.h"
|
#include "mlx/backend/metal/copy.h"
|
||||||
@@ -20,9 +20,9 @@ std::pair<std::string, std::string> make_index_args(
|
|||||||
std::ostringstream idx_args;
|
std::ostringstream idx_args;
|
||||||
std::ostringstream idx_arr;
|
std::ostringstream idx_arr;
|
||||||
for (int i = 0; i < nidx; ++i) {
|
for (int i = 0; i < nidx; ++i) {
|
||||||
idx_args << fmt::format(
|
idx_args << std::format(
|
||||||
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
|
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
|
||||||
idx_arr << fmt::format("idx{0}", i);
|
idx_arr << std::format("idx{0}", i);
|
||||||
if (i < nidx - 1) {
|
if (i < nidx - 1) {
|
||||||
idx_args << "\n";
|
idx_args << "\n";
|
||||||
idx_arr << ",";
|
idx_arr << ",";
|
||||||
@@ -59,7 +59,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
bool large = large_index || large_src || large_out;
|
bool large = large_index || large_src || large_out;
|
||||||
|
|
||||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||||
std::string kernel_name = fmt::format(
|
std::string kernel_name = std::format(
|
||||||
"gather{0}{1}_{2}_{3}_{4}",
|
"gather{0}{1}_{2}_{3}_{4}",
|
||||||
type_to_name(out),
|
type_to_name(out),
|
||||||
idx_type_name,
|
idx_type_name,
|
||||||
@@ -77,7 +77,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||||
|
|
||||||
// Index dimension specializations
|
// Index dimension specializations
|
||||||
kernel_source += fmt::format(
|
kernel_source += std::format(
|
||||||
gather_kernels,
|
gather_kernels,
|
||||||
type_to_name(out) + idx_type_name,
|
type_to_name(out) + idx_type_name,
|
||||||
out_type_str,
|
out_type_str,
|
||||||
@@ -238,7 +238,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
|
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
|
||||||
bool large_upd = upd.size() > INT32_MAX;
|
bool large_upd = upd.size() > INT32_MAX;
|
||||||
bool large = large_out || large_idx || large_upd;
|
bool large = large_out || large_idx || large_upd;
|
||||||
std::string kernel_name = fmt::format(
|
std::string kernel_name = std::format(
|
||||||
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
|
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
|
||||||
type_to_name(out),
|
type_to_name(out),
|
||||||
idx_type_name,
|
idx_type_name,
|
||||||
@@ -275,11 +275,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (reduce_type_ != Scatter::None) {
|
if (reduce_type_ != Scatter::None) {
|
||||||
op_type = fmt::format(fmt::runtime(op_type), out_type_str);
|
op_type = std::vformat(op_type, std::make_format_args(out_type_str));
|
||||||
}
|
}
|
||||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||||
|
|
||||||
kernel_source += fmt::format(
|
kernel_source += std::format(
|
||||||
scatter_kernels,
|
scatter_kernels,
|
||||||
type_to_name(out) + idx_type_name + "_" + op_name,
|
type_to_name(out) + idx_type_name + "_" + op_name,
|
||||||
out_type_str,
|
out_type_str,
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
constexpr std::string_view gemv_masked_kernel = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
|
|
||||||
const device {itype}* mat [[buffer(0)]],
|
|
||||||
const device {itype}* in_vec [[buffer(1)]],
|
|
||||||
device {itype}* out_vec [[buffer(3)]],
|
|
||||||
const constant int& in_vec_size [[buffer(4)]],
|
|
||||||
const constant int& out_vec_size [[buffer(5)]],
|
|
||||||
const constant int& marix_ld [[buffer(6)]],
|
|
||||||
const constant int& batch_ndim [[buffer(9)]],
|
|
||||||
const constant int* batch_shape [[buffer(10)]],
|
|
||||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
|
||||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
|
||||||
const device {outm_t}* out_mask [[buffer(20)]],
|
|
||||||
const device {opm_t}* mat_mask [[buffer(21)]],
|
|
||||||
const device {opm_t}* vec_mask [[buffer(22)]],
|
|
||||||
const constant int* mask_strides [[buffer(23)]],
|
|
||||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
|
||||||
)";
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
constexpr std::string_view steel_conv_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
|
|
||||||
const device {itype}* A [[buffer(0)]],
|
|
||||||
const device {itype}* B [[buffer(1)]],
|
|
||||||
device {itype}* C [[buffer(2)]],
|
|
||||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
|
||||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_conv_general_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
|
|
||||||
const device {itype}* A [[buffer(0)]],
|
|
||||||
const device {itype}* B [[buffer(1)]],
|
|
||||||
device {itype}* C [[buffer(2)]],
|
|
||||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
|
||||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
|
||||||
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
|
|
||||||
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
|
|
||||||
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
|
||||||
)";
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_fused_kernels = R"(
|
|
||||||
template [[host_name("{name}")]]
|
|
||||||
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
|
|
||||||
const device {itype} *A [[buffer(0)]],
|
|
||||||
const device {itype} *B [[buffer(1)]],
|
|
||||||
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
|
|
||||||
device {itype} *D [[buffer(3)]],
|
|
||||||
const constant GEMMParams* params [[buffer(4)]],
|
|
||||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
|
||||||
const constant int* batch_shape [[buffer(6)]],
|
|
||||||
const constant int64_t* batch_strides [[buffer(7)]],
|
|
||||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
|
||||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
|
||||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
|
||||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
|
||||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
|
||||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_masked_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
block_masked_gemm<
|
|
||||||
{itype},
|
|
||||||
{outmasktype},
|
|
||||||
{opmasktype},
|
|
||||||
{bm},
|
|
||||||
{bn},
|
|
||||||
{bk},
|
|
||||||
{wm},
|
|
||||||
{wn},
|
|
||||||
{trans_a},
|
|
||||||
{trans_b},
|
|
||||||
{mn_aligned},
|
|
||||||
{k_aligned}>(
|
|
||||||
const device {itype}* A [[buffer(0)]],
|
|
||||||
const device {itype}* B [[buffer(1)]],
|
|
||||||
device {itype}* D [[buffer(3)]],
|
|
||||||
const constant GEMMParams* params [[buffer(4)]],
|
|
||||||
const constant int* batch_shape [[buffer(6)]],
|
|
||||||
const constant int64_t* batch_strides [[buffer(7)]],
|
|
||||||
const device {outmasktype}* out_mask [[buffer(10)]],
|
|
||||||
const device {opmasktype}* lhs_mask [[buffer(11)]],
|
|
||||||
const device {opmasktype}* rhs_mask [[buffer(12)]],
|
|
||||||
const constant int* mask_strides [[buffer(13)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_splitk_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
gemm_splitk<
|
|
||||||
{itype},
|
|
||||||
{otype},
|
|
||||||
{bm},
|
|
||||||
{bn},
|
|
||||||
{bk},
|
|
||||||
{wm},
|
|
||||||
{wn},
|
|
||||||
{trans_a},
|
|
||||||
{trans_b},
|
|
||||||
{mn_aligned},
|
|
||||||
{k_aligned}>(
|
|
||||||
const device {itype}* A [[buffer(0)]],
|
|
||||||
const device {itype}* B [[buffer(1)]],
|
|
||||||
device {otype}* C [[buffer(2)]],
|
|
||||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
gemm_splitk_accum<{atype}, {otype}>(
|
|
||||||
const device {atype}* C_split [[buffer(0)]],
|
|
||||||
device {otype}* D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
gemm_splitk_accum_axpby<{atype}, {otype}>(
|
|
||||||
const device {atype}* C_split [[buffer(0)]],
|
|
||||||
device {otype}* D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
const device {otype}* C [[buffer(5)]],
|
|
||||||
const constant int& ldc [[buffer(6)]],
|
|
||||||
const constant int& fdc [[buffer(7)]],
|
|
||||||
const constant float& alpha [[buffer(8)]],
|
|
||||||
const constant float& beta [[buffer(9)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
|
||||||
)";
|
|
||||||
@@ -1,16 +1,11 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/metal/jit/arange.h"
|
#include "mlx/backend/metal/jit/arange.h"
|
||||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
|
||||||
#include "mlx/backend/metal/jit/includes.h"
|
#include "mlx/backend/metal/jit/includes.h"
|
||||||
#include "mlx/backend/metal/jit/softmax.h"
|
#include "mlx/backend/metal/jit/softmax.h"
|
||||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
|
||||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
|
||||||
#include "mlx/backend/metal/kernels.h"
|
#include "mlx/backend/metal/kernels.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
|
||||||
using namespace fmt::literals;
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string op_name(const array& arr) {
|
std::string op_name(const array& arr) {
|
||||||
@@ -26,7 +21,7 @@ MTL::ComputePipelineState* get_arange_kernel(
|
|||||||
auto lib = d.get_library(kernel_name, [&]() {
|
auto lib = d.get_library(kernel_name, [&]() {
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::arange()
|
kernel_source << metal::utils() << metal::arange()
|
||||||
<< fmt::format(
|
<< std::format(
|
||||||
arange_kernels,
|
arange_kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
get_type_string(out.dtype()));
|
get_type_string(out.dtype()));
|
||||||
@@ -259,7 +254,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
|||||||
auto lib = d.get_library(lib_name, [&] {
|
auto lib = d.get_library(lib_name, [&] {
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::softmax()
|
kernel_source << metal::utils() << metal::softmax()
|
||||||
<< fmt::format(
|
<< std::format(
|
||||||
softmax_kernels,
|
softmax_kernels,
|
||||||
lib_name,
|
lib_name,
|
||||||
get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
@@ -445,17 +440,17 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_fused()
|
<< metal::steel_gemm_fused()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_gemm_fused_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"gemm",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn,
|
wn,
|
||||||
"trans_a"_a = transpose_a,
|
transpose_a,
|
||||||
"trans_b"_a = transpose_b);
|
transpose_b);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
@@ -480,20 +475,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_splitk()
|
<< metal::steel_gemm_splitk()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_gemm_splitk_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"gemm_splitk",
|
||||||
"itype"_a = get_type_string(in.dtype()),
|
get_type_string(in.dtype()),
|
||||||
"otype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn,
|
wn,
|
||||||
"trans_a"_a = transpose_a,
|
transpose_a,
|
||||||
"trans_b"_a = transpose_b,
|
transpose_b,
|
||||||
"mn_aligned"_a = mn_aligned,
|
mn_aligned,
|
||||||
"k_aligned"_a = k_aligned);
|
k_aligned);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@@ -510,13 +505,12 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_splitk()
|
<< metal::steel_gemm_splitk()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
fmt::runtime(
|
lib_name,
|
||||||
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
|
axbpy ? "gemm_splitk_accum_axpby"
|
||||||
: steel_gemm_splitk_accum_kernels),
|
: "gemm_splitk_accum",
|
||||||
"name"_a = lib_name,
|
get_type_string(in.dtype()),
|
||||||
"atype"_a = get_type_string(in.dtype()),
|
get_type_string(out.dtype()));
|
||||||
"otype"_a = get_type_string(out.dtype()));
|
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@@ -547,21 +541,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
|||||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_masked()
|
<< metal::steel_gemm_masked()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_gemm_masked_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"block_masked_gemm",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"outmasktype"_a = out_mask_type,
|
out_mask_type,
|
||||||
"opmasktype"_a = op_mask_type,
|
op_mask_type,
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn,
|
wn,
|
||||||
"trans_a"_a = transpose_a,
|
transpose_a,
|
||||||
"trans_b"_a = transpose_b,
|
transpose_b,
|
||||||
"mn_aligned"_a = mn_aligned,
|
mn_aligned,
|
||||||
"k_aligned"_a = k_aligned);
|
k_aligned);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@@ -590,20 +584,19 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
|||||||
auto op_mask_type =
|
auto op_mask_type =
|
||||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||||
kernel_source << metal::utils() << metal::gemv_masked()
|
kernel_source << metal::utils() << metal::gemv_masked()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
gemv_masked_kernel,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
(transpose_mat) ? "gemv_t_masked" : "gemv_masked",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"outm_t"_a = out_mask_type,
|
out_mask_type,
|
||||||
"opm_t"_a = op_mask_type,
|
op_mask_type,
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"sm"_a = sm,
|
sm,
|
||||||
"sn"_a = sn,
|
sn,
|
||||||
"tm"_a = tm,
|
tm,
|
||||||
"tn"_a = tn,
|
tn,
|
||||||
"trans"_a = transpose_mat ? "t_" : "",
|
contiguous ? 0 : 1);
|
||||||
"nc"_a = contiguous ? "0" : "1");
|
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@@ -624,17 +617,17 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
|||||||
auto lib = d.get_library(lib_name, [&]() {
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_conv_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"implicit_gemm_conv_2d",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn,
|
wn,
|
||||||
"n_channels"_a = n_channel_specialization,
|
n_channel_specialization,
|
||||||
"small_filter"_a = small_filter);
|
small_filter);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@@ -654,15 +647,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::conv()
|
kernel_source << metal::utils() << metal::conv()
|
||||||
<< metal::steel_conv_general()
|
<< metal::steel_conv_general()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_conv_general_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"implicit_gemm_conv_2d_general",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn);
|
wn);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <fmt/format.h>
|
#include <format>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
@@ -218,7 +218,7 @@ get_template_definition(std::string name, std::string func, Args... args) {
|
|||||||
};
|
};
|
||||||
(add_arg(args), ...);
|
(add_arg(args), ...);
|
||||||
s << ">";
|
s << ">";
|
||||||
return fmt::format(
|
return std::format(
|
||||||
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
|
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
|
||||||
name,
|
name,
|
||||||
s.str());
|
s.str());
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ array compute_dynamic_offset(
|
|||||||
auto dtype = indices.dtype();
|
auto dtype = indices.dtype();
|
||||||
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
|
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
|
||||||
auto lib = d.get_library(lib_name, [dtype]() {
|
auto lib = d.get_library(lib_name, [dtype]() {
|
||||||
return fmt::format(
|
return std::format(
|
||||||
R"(
|
R"(
|
||||||
[[kernel]] void compute_dynamic_offset_{0}(
|
[[kernel]] void compute_dynamic_offset_{0}(
|
||||||
constant const {1}* indices [[buffer(0)]],
|
constant const {1}* indices [[buffer(0)]],
|
||||||
|
|||||||
Reference in New Issue
Block a user