Compare commits


82 Commits

Author SHA1 Message Date
Angelos Katharopoulos
4515866024 Change the linux test to ubuntu 24.04 2025-01-20 22:58:05 -08:00
Angelos Katharopoulos
6fe2b82926 Add gcc 13 to the linux build 2025-01-17 18:00:24 -08:00
Angelos Katharopoulos
c75b5e9d19 Do not build wheels every time 2025-01-17 17:42:11 -08:00
Angelos Katharopoulos
6f12eda549 Test for 13.5 as well 2025-01-17 17:39:20 -08:00
Angelos Katharopoulos
a541fe9312 Set deployment target 13.5 2025-01-17 17:16:36 -08:00
Angelos Katharopoulos
2bdd20f257 Test XCode 16 2025-01-17 14:14:45 -08:00
Angelos Katharopoulos
aa7b9688ce Move some kernels to get_template_definition 2025-01-17 13:24:17 -08:00
Angelos Katharopoulos
0a41393dba Replace fmt::format with std::format 2025-01-17 11:24:16 -08:00
Angelos Katharopoulos
e300a01f4a Testing C++ 20 2025-01-16 18:24:16 -08:00
Awni Hannun
f288db8d34 Fix synchronization bug for in stream async works (#1768) 2025-01-15 06:07:34 -08:00
Awni Hannun
33421c1dd3 Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth

* add grad of module test
2025-01-14 14:33:18 -08:00
Nripesh Niketan
5cc5201914 feat: Add orthogonal initializer and corresponding tests (#1651)
* feat: Add orthogonal initializer and corresponding tests

* lint

* Add acknowledgements

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-13 07:29:20 -08:00
Awni Hannun
252e423e81 fix and cleanup event signal/wait for metal (#1765) 2025-01-10 18:37:26 -08:00
wrmsr
a4a2764a52 Fix broadcast_arrays python sig (#1763) 2025-01-10 12:33:26 -08:00
Cheng
ab8e832c18 0ul is not size_t on MSVC (#1762) 2025-01-10 12:33:11 -08:00
Angelos Katharopoulos
1ce0c0fcb0 Bump version (#1761) 2025-01-09 13:48:20 -08:00
Awni Hannun
657f466402 use sdpa and exportable functions in transformer multi head attention (#1760) 2025-01-09 13:11:55 -08:00
Alex Barron
c7b0300af5 Fix batched qmv bug (#1758) 2025-01-09 11:45:57 -08:00
Awni Hannun
da8c885784 Simplify removes no-ops from the tape (#1759)
* simplify removes no-ops from the tape

* comment
2025-01-09 11:23:19 -08:00
Awni Hannun
1ccaf80575 Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
2025-01-09 11:04:24 -08:00
Cheng
ec36bfa317 Include command stdout in error message (#1756)
* Include command stdout in error message

* On Windows pclose returns the exit code
2025-01-08 07:17:03 -08:00
Cheng
b8f76f717a Print exceptions in eval_cpu/eval_gpu and abort (#1754) 2025-01-08 06:31:09 -08:00
Awni Hannun
d1766f2c70 Add boolean mask support in vector SDPA (#1757) 2025-01-07 20:24:53 -08:00
Awni Hannun
516ded618b Dynamic slicing (#1741)
* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
2025-01-07 14:02:16 -08:00
Jesper Stemann Andersen
c9c81d0584 Added additional missing unordered_map include that fixes build on FreeBSD (#1755) 2025-01-07 08:27:55 -08:00
Angelos Katharopoulos
545f84d905 Refactor distributed backend (#1752) 2025-01-06 17:33:15 -08:00
Awni Hannun
d5ec172c95 Allow boolean mask in sdpa (#1753)
* allow boolean mask in sdpa

* more permissive donation in ternary
2025-01-06 16:57:07 -08:00
Angelos Katharopoulos
25b3a3e541 Optionally specify names for arrays when exporting (#1749) 2025-01-06 13:07:46 -08:00
Awni Hannun
058d6ce683 mpi send use input as output (#1750)
* mpi send use input as output

* move earlier
2025-01-06 06:08:43 -08:00
Angelos Katharopoulos
eab93985b8 Update custom function docs (#1748) 2025-01-03 16:35:25 -08:00
Awni Hannun
b51d70a83c export docs (#1747) 2025-01-03 15:04:17 -08:00
Awni Hannun
259025100e Fix nd ternary on GPU (#1746) 2025-01-03 11:52:17 -08:00
Awni Hannun
c9d30aa6ac MLX in C++ example (#1736)
* MLX in C++ example

* nits

* fix docs
2025-01-02 19:09:04 -08:00
Angelos Katharopoulos
8544b42007 Add namespace (#1745) 2025-01-02 16:49:23 -08:00
Awni Hannun
6fa0501387 Fix concatenate/slice_update vjp + reduce binary size (#1735)
* fix concatenate vjp + reduce binary size

* also cast in slice update
2025-01-02 16:36:33 -08:00
Awni Hannun
ae69cb15e9 shapeless compile in docs and partially shapeless reshape (#1742) 2025-01-02 16:24:42 -08:00
Awni Hannun
a64a8dfe45 fix extension (#1740) 2025-01-02 16:16:16 -08:00
Venkata Naga Aditya Datta Chivukula
491fa95b1f Added Kronecker Product (#1728) 2025-01-02 16:00:34 -08:00
Danilo Peixoto
92ec632ad5 Fix Distributed Communication documentation (#1731)
* Add missing `size()` method call for group
2025-01-02 14:08:38 -08:00
Cheng
8ecdfb718b Fix export.cpp compilation with MSVC (#1737) 2024-12-29 06:56:30 -08:00
Awni Hannun
4ba0c24a8f Export / import functions to / from a file (#1642)
* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
2024-12-24 11:19:13 -08:00
Cheng
935c8c4bb1 Make mx.compile work on Windows (#1697)
* Invoke MSVC on Windows in mx.compile

* Export kernel symbol on MSVC

* Remove unused template

* Parse env pairs in a robust way

* No need of cassert

* Remove unnecessary helpers

* Fix right trim

* Move command building to a separate file

* Missing header

* Do not pollute cwd with cl.exe

* Simplify str concat

* Pass output dir

* Fix styling
2024-12-24 07:02:33 -08:00
Valentin Roussellet
88f993da38 Explicit parentheses around some logical operators (#1732)
* fix some warnings

* format
2024-12-24 07:02:20 -08:00
Awni Hannun
ebfe64b92d shapeless slice update and broadcast when possible (#1727) 2024-12-23 11:25:15 -08:00
Awni Hannun
0308e9af71 Allow offset to be an mx.array for mx.fast.rope (#1724)
* allow offset for rope

* comment
2024-12-19 15:51:44 -08:00
Awni Hannun
c3628eea49 Add mx.finfo and use it when making causal mask (#1726)
* finfo

* fixes

* docs
2024-12-19 14:52:41 -08:00
Awni Hannun
e03f0372b1 More shape type (#1705)
* more shape type

* fix
2024-12-19 08:08:20 -08:00
Alex Barron
f17536af9c More lenient mask type check in SDPA (#1723)
* check mask type

* require promotion
2024-12-18 19:41:38 -08:00
Cheng
ed4ec81bca Link python extension with mlx statically on Windows (#1716)
* Link python extension with mlx statically on Windows

* More readable code
2024-12-18 19:26:04 -08:00
Awni Hannun
7480059306 track resource limit and throw if exceeded (#1718) 2024-12-18 18:45:58 -08:00
Awni Hannun
8bae22b0fa fix deletion of non-evaled arrays with siblings (#1714) 2024-12-18 18:45:36 -08:00
Alex Barron
49c34c4161 check mask type (#1721) 2024-12-18 14:25:18 -08:00
Awni Hannun
5548fcc96d fix synch race (#1719) 2024-12-18 12:25:16 -08:00
Cheng
070bd433ab Shorter kernel name for Windows (#1701)
* Shorter kernel name for Windows

* Only hash the clipped part
2024-12-17 18:51:38 -08:00
Cheng
c8fb54951a Define NOMINMAX before windows.h (#1715) 2024-12-17 18:51:24 -08:00
Awni Hannun
f110357aaa Bump nanobind to 2.4 + fix (#1710)
* bump nanobind to 2.4 + fix

* fix
2024-12-17 10:57:54 -08:00
Tomohiro Oga
a6b426422e add cubic to type hinting for upsample (#1709) 2024-12-17 07:30:23 -08:00
Awni Hannun
d03c01dfbc fix unflatten vjp (#1708) 2024-12-16 18:37:57 -08:00
Jesper Stemann Andersen
a82996e9fb io/load: Enabled pread implementation for mingw32 (#1706) 2024-12-16 07:20:45 -08:00
Cheng
af5a614aad Eval before cleanup so model file is unlocked (#1702) 2024-12-14 21:41:49 -08:00
Cheng
f9640e049d Install mlx.dll into the same dir with python bindings on Windows (#1690)
* Install mlx.dll into the same dir with python bindings on Windows

* Set BUILD_SHARED_LIBS for dlfcn-win32

* Update cmake requirements to 3.25

* Fix cmake style
2024-12-13 19:50:39 -08:00
Cheng
4768c61b57 Make sure gguf_ctx is closed when error happens (#1699) 2024-12-13 19:50:19 -08:00
Cheng
dfccd17ab9 Use psutil to get memory info on Windows (#1700) 2024-12-13 19:50:13 -08:00
Cheng
635117c5d4 Read/write files in binary mode (#1698) 2024-12-13 17:37:05 -08:00
Awni Hannun
50f3535693 Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places

* few more

* few more

* fix
2024-12-12 17:00:44 -08:00
Awni Hannun
9111999af3 Fix small sort with metal validation (#1695) 2024-12-12 09:21:45 -08:00
Awni Hannun
6bd28d246e Allow no copy negative strides in as_strided and slice (#1688)
* allow no copy negative strides in as_strided and slice

* fix jit

* fix jit
2024-12-12 08:59:45 -08:00
Cheng
4d595a2a39 Make compiled preamble work in MSVC (#1675)
* Make compiled preamble work in MSVC

* Remove logging

* Only use powershell for MSVC
2024-12-12 08:55:49 -08:00
Awni Hannun
3a21f61772 Fix build (#1693) 2024-12-11 23:56:25 -08:00
Awni Hannun
4e1e9520e1 Flatten and unflatten (#1692)
* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
2024-12-11 21:51:37 -08:00
Cheng
0bf19037ca Remove "using namespace mlx::core" in python/src (#1689) 2024-12-11 15:45:39 -08:00
Awni Hannun
f3dfa36a3a Fix x86 tests (#1691)
* fix x86 tests

* comment
2024-12-11 07:47:18 -08:00
Cheng
4f9b60dd53 Remove "using namespace mlx::core" in benchmarks/examples (#1685)
* Remove "using namespace mlx::core" in benchmarks/examples

* Fix building example extension

* A missing one in comment

* Fix building on M chips
2024-12-11 07:08:29 -08:00
Awni Hannun
f76a49e555 ExpandDims primitive (#1687)
* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
2024-12-10 16:39:07 -08:00
Cheng
310ad8d9db Build OpenBLAS from source code for MSVC (#1674)
* Download OpenBLAS binaries when building with MSVC

* Download dlfcn-win32

* Link with dlfcn-win32 correctly

* Build OpenBLAS from source code

* Link with openblas statically

* Link with BLAS privately
2024-12-10 16:14:44 -08:00
Cheng
56db268f47 Provide a pread implementation for MSVC (#1666) 2024-12-10 15:55:53 -08:00
Cheng
92ab6bdeb8 Fix shared library not exporting symbols on Windows (#1684)
* Fix shared library not exporting symbols on Windows

* Function name style
2024-12-10 13:59:14 -08:00
Cheng
0070e360a1 Disable MSVC warnings (#1680) 2024-12-09 19:41:14 -08:00
Amethyst Shen
9df8fed046 Metal-cpp version bump (#1668)
* Metal-cpp version bump

Apple has released the stable version of Metal-cpp for macOS 15 and iOS 18. CMakeLists.txt is updated to build with it instead of the beta one.

* Fix style with cmake-format
2024-12-09 19:40:35 -08:00
Cheng
a59fae040f Fix library output directory for MSVC (#1681) 2024-12-09 19:07:50 -08:00
Awni Hannun
29a620cab2 No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding

* fix inadvertant cast

* add tol
2024-12-09 18:57:38 -08:00
Cheng
87d7a2520e Use Py_ssize_t in python bindings (#1678)
* Use Py_ssize_t in python bindings

* Args passed to std::max must be same type
2024-12-09 12:59:19 -08:00
196 changed files with 9359 additions and 4071 deletions


@@ -24,7 +24,7 @@ jobs:
type: boolean
default: false
macos:
xcode: "15.2.0"
xcode: "16.0.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
@@ -70,8 +70,8 @@ jobs:
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
machine:
image: ubuntu-2404:2024.11.1
steps:
- checkout
@@ -84,30 +84,33 @@ jobs:
- run:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install numpy
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get update -y
sudo apt-get install -y python3.9 python3.9-distutils python3.9-dev
python3.9 -m pip install --upgrade cmake
python3.9 -m pip install nanobind==2.4.0
python3.9 -m pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libopenblas-dev liblapacke-dev openmpi-bin libopenmpi-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF -DPython_EXECUTABLE=/usr/bin/python3.9" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
python3.9 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF -DPython_EXECUTABLE=/usr/bin/python3.9" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
python3.9 setup.py develop
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
python3.9 -m pip install typing_extensions
python3.9 setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
python3.9 -m unittest discover python/tests -v
- run:
name: Build CPP only
command: |
@@ -122,7 +125,10 @@ jobs:
parameters:
xcode_version:
type: string
default: "15.2.0"
default: "16.0.0"
deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
@@ -137,7 +143,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
@@ -146,7 +152,9 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
pip install -e . -v
- run:
name: Generate package stubs
command: |
@@ -173,7 +181,11 @@ jobs:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
mkdir -p build
cd build/
cmake .. \
-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>
make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
@@ -188,14 +200,15 @@ jobs:
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
-DMLX_METAL_JIT=ON \
-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
CMAKE_ARGS="-DMLX_METAL_JIT=ON -DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
@@ -208,7 +221,10 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "15.2.0"
default: "16.0.0"
deployment_target:
type: string
default: ""
build_env:
type: string
default: ""
@@ -226,7 +242,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install twine
@@ -237,6 +253,7 @@ jobs:
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
pip install . -v
- run:
name: Generate package stubs
@@ -250,6 +267,7 @@ jobs:
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
python -m build -w
- when:
condition: << parameters.build_env >>
@@ -291,7 +309,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
@@ -330,9 +348,10 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
- linux_build_and_test
- build_documentation
- build_documentation
build_pypi_release:
when:
@@ -350,7 +369,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
build_env: ["PYPI_RELEASE=1"]
- build_documentation:
filters:
@@ -374,7 +394,8 @@ workflows:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -387,7 +408,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
weekly_build:
when:
and:
@@ -398,7 +420,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:

.gitignore

@@ -76,6 +76,9 @@ build/
*.out
*.app
# Debug symbols
*.pdb
# VSCode
.vscode/
.DS_Store


@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.


@@ -1,10 +1,10 @@
cmake_minimum_required(VERSION 3.24)
cmake_minimum_required(VERSION 3.25)
project(mlx LANGUAGES C CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
@@ -20,12 +20,14 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.21.1)
set(MLX_VERSION 0.22.0)
endif()
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
# --------------------- Processor tests -------------------------
@@ -93,8 +95,7 @@ elseif(MLX_BUILD_METAL)
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -113,16 +114,55 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
set(MLX_BUILD_GGUF OFF)
# There is no prebuilt OpenBLAS distribution for MSVC.
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
endif()
# Windows implementation of dlfcn.h APIs.
FetchContent_Declare(
dlfcn-win32
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.1
EXCLUDE_FROM_ALL)
block()
set(BUILD_SHARED_LIBS OFF)
FetchContent_MakeAvailable(dlfcn-win32)
endblock()
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
target_link_libraries(mlx PRIVATE dl)
endif()
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
# Download and build OpenBLAS from source code.
FetchContent_Declare(
openblas
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
GIT_TAG v0.3.28
EXCLUDE_FROM_ALL)
set(BUILD_STATIC_LIBS ON) # link statically
set(NOFORTRAN ON) # msvc has no fortran compiler
FetchContent_MakeAvailable(openblas)
target_link_libraries(mlx PRIVATE openblas)
target_include_directories(
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
else()
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
@@ -140,7 +180,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
@@ -153,14 +193,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
if(WIN32)
find_package(dlfcn-win32 REQUIRED)
message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
endif()
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)
@@ -190,14 +223,6 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(
@@ -207,8 +232,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()


@@ -5,35 +5,35 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_value_and_grad() {
auto x = ones({200, 1000});
eval(x);
auto fn = [](array x) {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
for (int i = 0; i < 20; ++i) {
x = log(exp(x));
x = mx::log(mx::exp(x));
}
return sum(x);
return mx::sum(x);
};
auto grad_fn = grad(fn);
auto grad_fn = mx::grad(fn);
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
return std::vector<array>{value, dfdx};
return std::vector<mx::array>{value, dfdx};
};
TIME(independent_value_and_grad);
auto value_and_grad_fn = value_and_grad(fn);
auto value_and_grad_fn = mx::value_and_grad(fn);
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<array>{value, dfdx};
return std::vector<mx::array>{value, dfdx};
};
TIME(combined_value_and_grad);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_value_and_grad();
}


@@ -4,21 +4,21 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_add_op() {
std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back());
}
set_default_device(Device::cpu);
set_default_device(mx::Device::cpu);
for (auto size : sizes) {
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
std::cout << "Size " << size << std::endl;
TIMEM("cpu", add, a, b, Device::cpu);
TIMEM("gpu", add, a, b, Device::gpu);
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
}
}


@@ -6,105 +6,105 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_irregular_binary_ops_1D() {
auto device = default_device();
auto device = mx::default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", add, a, b, device);
TIMEM("1D strided", mx::add, a, b, device);
}
void time_irregular_binary_ops_2D() {
auto device = default_device();
auto device = mx::default_device();
int size = 2048;
auto a = random::uniform({size, size});
auto b = random::uniform({size, size});
eval(a, b);
TIMEM("2D regular", add, a, b, device);
auto a = mx::random::uniform({size, size});
auto b = mx::random::uniform({size, size});
mx::eval(a, b);
TIMEM("2D regular", mx::add, a, b, device);
b = transpose(b);
eval(b);
TIMEM("2D transpose", add, a, b, device);
b = mx::transpose(b);
mx::eval(b);
TIMEM("2D mx::transpose", mx::add, a, b, device);
b = random::uniform({size});
eval(b);
TIMEM("2D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({size});
mx::eval(b);
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
b = reshape(b, {size, 1});
eval(b);
TIMEM("2D broadcast dim 1", add, a, b, device);
b = mx::reshape(b, {size, 1});
mx::eval(b);
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
}
void time_irregular_binary_ops_3D() {
auto device = default_device();
auto device = mx::default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = random::uniform({d0, d1, d2});
auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", add, a, b, device);
auto a = mx::random::uniform({d0, d1, d2});
auto b = mx::random::uniform({d0, d1, d2});
TIMEM("3D regular", mx::add, a, b, device);
b = transpose(b, {0, 2, 1});
TIMEM("3D transpose", add, a, b, device);
b = mx::transpose(b, {0, 2, 1});
TIMEM("3D mx::transpose", mx::add, a, b, device);
b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", add, a, b, device);
b = mx::random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", add, a, b, device);
b = mx::random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
b = mx::random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
b = mx::random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
b = mx::random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
}
void time_irregular_binary_ops_4D() {
auto device = default_device();
auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512};
auto a = random::uniform(shape);
auto b = random::uniform(shape);
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
TIMEM("4D regular", add, a, b, device);
TIMEM("4D regular", mx::add, a, b, device);
b = transpose(b, {0, 1, 3, 2});
TIMEM("4D transpose", add, a, b, device);
b = mx::transpose(b, {0, 1, 3, 2});
TIMEM("4D mx::transpose", mx::add, a, b, device);
std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
b = random::uniform(shape);
b = mx::random::uniform(shape);
std::ostringstream msg;
msg << om << i;
TIMEM(msg.str(), add, a, b, device);
TIMEM(msg.str(), mx::add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
shape[k] = a.shape(k);
}
}
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
}
void time_irregular_reshape() {
auto device = default_device();
auto device = mx::default_device();
std::vector<int> shape;
auto reshape_fn = [&shape, device](const array& a) {
return reshape(a, shape, device);
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
};
int size = 64;
int d = 2 * size;
auto a = random::uniform({d, d, d});
auto a = mx::random::uniform({d, d, d});
shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);
a = transpose(a);
a = mx::transpose(a);
shape = {8 * size, size, size};
TIMEM("3D transpose", reshape_fn, a);
TIMEM("3D mx::transpose", reshape_fn, a);
a = transpose(a, {1, 2, 0});
a = mx::transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
TIMEM("3D transpose dims 1 2", reshape_fn, a);
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
a = broadcast_to(random::uniform({d, d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);
a = broadcast_to(random::uniform({d}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}
void time_irregular_astype_1D() {
auto device = default_device();
auto device = mx::default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
auto a = mx::random::uniform({size});
a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", astype, a, int32, device);
TIMEM("1D strided", mx::astype, a, mx::int32, device);
}
void time_irregular_astype_2D() {
auto device = default_device();
auto device = mx::default_device();
int size = 2048;
std::vector<int> shape = {size, size};
auto a = random::uniform(shape);
TIMEM("2D regular", astype, a, int32, device);
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);
a = transpose(a);
TIMEM("2D transpose", astype, a, int32, device);
a = mx::transpose(a);
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
}
int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? Device::gpu : Device::cpu);
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
}
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();


@@ -3,20 +3,20 @@
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
namespace mx = mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto full_fp32 = [&]() { return full(shape, 3.3f); };
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
TIME(full_fp32);
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
TIME(zeros_fp32);
auto ones_fp32 = [&]() { return ones(shape, float32); };
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
TIME(ones_fp32);
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
@@ -24,194 +24,196 @@ void time_type_conversions() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto device = default_device();
auto device = mx::default_device();
auto a = zeros(shape, float32);
eval(a);
TIMEM("float32 to int32", astype, a, int32, device);
TIMEM("float32 to uint32", astype, a, uint32, device);
auto a = mx::zeros(shape, mx::float32);
mx::eval(a);
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
a = zeros(shape, int32);
eval(a);
TIMEM("int32 to float32", astype, a, float32, device);
a = mx::zeros(shape, mx::int32);
mx::eval(a);
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
a = zeros(shape, bool_);
eval(a);
TIMEM("bool to float32", astype, a, float32, device);
TIMEM("bool to int32", astype, a, int32, device);
TIMEM("bool to uint32", astype, a, uint32, device);
a = mx::zeros(shape, mx::bool_);
mx::eval(a);
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
}
void time_random_generation() {
int M = 2000;
int N = 500;
auto uniform = [&]() { return random::uniform({M, N}, float32); };
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
TIME(uniform);
auto normal = [&]() { return random::normal({M, N}, float32); };
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
TIME(normal);
}
void time_unary_ops() {
int M = 2000;
int N = 500;
auto device = default_device();
auto device = mx::default_device();
auto a = random::normal({M, N});
eval(a);
auto a = mx::random::normal({M, N});
mx::eval(a);
TIME(mlx::core::abs, a, device);
TIME(negative, a, device);
TIME(sign, a, device);
TIME(square, a, device);
TIME(mx::negative, a, device);
TIME(mx::sign, a, device);
TIME(mx::square, a, device);
TIME(mlx::core::sqrt, a, device);
TIME(rsqrt, a, device);
TIME(mx::rsqrt, a, device);
TIME(mlx::core::exp, a, device);
a = random::uniform({M, N});
a = mx::random::uniform({M, N});
TIME(mlx::core::log, a, device);
}
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
auto condition = mx::random::randint(0, 2, {M, N, K});
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(add, a, b, device);
TIME(subtract, a, b, device);
TIME(multiply, a, b, device);
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
TIME(mx::add, a, b, device);
TIME(mx::subtract, a, b, device);
TIME(mx::multiply, a, b, device);
TIME(mx::divide, a, b, device);
TIME(mx::maximum, a, b, device);
TIME(mx::minimum, a, b, device);
TIME(mx::where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
TIMEM("vector-scalar", subtract, a, b, device);
TIMEM("scalar-vector", subtract, b, a, device);
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = mx::array({true});
b = mx::random::uniform({1});
mx::eval(b);
TIMEM("scalar", mx::add, a, b, device);
TIMEM("vector-scalar", mx::subtract, a, b, device);
TIMEM("scalar-vector", mx::subtract, b, a, device);
TIMEM("scalar", mx::multiply, a, b, device);
TIMEM("vector-scalar", mx::divide, a, b, device);
TIMEM("scalar-vector", mx::divide, b, a, device);
TIMEM("scalar-vector", mx::where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
mx::eval(a, b);
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
}
void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50;
auto a = random::uniform({M, N, O, P});
auto b = random::uniform({M, N, O, P});
auto device = default_device();
eval(a, b);
TIMEM("non-strided", add, a, b, device);
a = transpose(a, {1, 0, 2, 3});
b = transpose(b, {3, 2, 0, 1});
eval(a, b);
TIMEM("strided", add, a, b, device);
auto a = mx::random::uniform({M, N, O, P});
auto b = mx::random::uniform({M, N, O, P});
auto device = mx::default_device();
mx::eval(a, b);
TIMEM("non-strided", mx::add, a, b, device);
a = mx::transpose(a, {1, 0, 2, 3});
b = mx::transpose(b, {3, 2, 0, 1});
mx::eval(a, b);
TIMEM("strided", mx::add, a, b, device);
}
void time_comparisons() {
int M = 1000, N = 100, K = 10;
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(equal, a, b, device);
TIME(greater, a, b, device);
TIME(greater_equal, a, b, device);
TIME(less, a, b, device);
TIME(less_equal, a, b, device);
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::equal, a, b, device);
TIME(mx::greater, a, b, device);
TIME(mx::greater_equal, a, b, device);
TIME(mx::less, a, b, device);
TIME(mx::less_equal, a, b, device);
}
void time_matvec() {
int M = 2000, N = 200;
auto a = random::uniform({M, N});
auto b = random::uniform({N});
auto c = random::uniform({M});
eval(a, b, c);
auto matvec = [&]() { return matmul(a, b); };
auto a = mx::random::uniform({M, N});
auto b = mx::random::uniform({N});
auto c = mx::random::uniform({M});
mx::eval(a, b, c);
auto matvec = [&]() { return mx::matmul(a, b); };
TIME(matvec);
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
auto a = random::uniform({M, K});
auto b = random::uniform({K, N});
auto device = default_device();
eval(a, b);
TIME(matmul, a, b, device);
auto a = mx::random::uniform({M, K});
auto b = mx::random::uniform({K, N});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::matmul, a, b, device);
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
TIME(transpose_matmul);
}
void time_reductions() {
auto a = random::normal({10000, 1000});
eval(a);
auto sum_all = [&a]() { return sum(a, false); };
auto a = mx::random::normal({10000, 1000});
mx::eval(a);
auto sum_all = [&a]() { return mx::sum(a, false); };
TIME(sum_all);
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
TIME(sum_along_0);
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
TIME(sum_along_1);
auto prod_all = [&a]() { return prod(a, false); };
auto prod_all = [&a]() { return mx::prod(a, false); };
TIME(prod_all);
auto all_true = [&a]() { return all(a, false); };
auto all_true = [&a]() { return mx::all(a, false); };
TIME(all_true);
auto all_along_0 = [&a]() { return all(a, 0, false); };
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
TIME(all_along_0);
auto all_along_1 = [&a]() { return all(a, 1, false); };
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
TIME(all_along_1);
auto any_true = [&a]() { return any(a, false); };
auto any_true = [&a]() { return mx::any(a, false); };
TIME(any_true);
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
TIME(argmin_along_0);
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
}
void time_gather_scatter() {
auto a = random::normal({1000, 768});
eval(a);
auto indices = random::randint(0, 1000, {256});
eval(indices);
auto a = mx::random::normal({1000, 768});
mx::eval(a);
auto indices = mx::random::randint(0, 1000, {256});
mx::eval(indices);
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
TIME(embedding_lookup);
indices = random::randint(0, 768 * 1000, {256 * 768});
eval(indices);
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
mx::eval(indices);
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
auto single_element_lookup = [&a, &indices]() {
return mx::take(a, indices);
};
TIME(single_element_lookup);
indices = random::randint(0, 1000, {256});
auto updates = random::normal({256, 1, 768});
eval(indices, updates);
indices = mx::random::randint(0, 1000, {256});
auto updates = mx::random::normal({256, 1, 768});
mx::eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -223,10 +225,10 @@ void time_gather_scatter() {
};
TIME(embedding_add);
a = reshape(a, {-1});
indices = random::randint(0, 768 * 1000, {768 * 256});
updates = random::normal({256 * 768, 1});
eval(a, indices, updates);
a = mx::reshape(a, {-1});
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
updates = mx::random::normal({256 * 768, 1});
mx::eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -240,21 +242,21 @@ void time_gather_scatter() {
}
void time_divmod() {
auto a = random::normal({1000});
auto b = random::normal({1000});
eval({a, b});
auto a = mx::random::normal({1000});
auto b = mx::random::normal({1000});
mx::eval({a, b});
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_creation_ops();
time_type_conversions();
time_unary_ops();


@@ -12,7 +12,7 @@ dtype = mx.float16
loops = 10
def attention(q, k, v):
def attention(q, k, v, mask=None):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
@@ -20,6 +20,9 @@ def attention(q, k, v):
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
if mask is not None:
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
@@ -29,9 +32,9 @@ def attention(q, k, v):
return q
def sdpa(q, k, v):
def sdpa(q, k, v, mask=None):
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
return q
@@ -53,6 +56,26 @@ def time_self_attention_sdpa():
time_fn(sdpa, q, k, v)
def time_self_attention_sdpa_with_mask():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mask = mx.full((L,), True)
mask[L // 2 :] = False
mx.eval(q, k, v, mask)
def sdpa_mask(*args):
return sdpa(*args, mask=mask)
def attention_mask(*args):
return attention(*args, mask=mask)
time_fn(attention_mask, q, k, v)
time_fn(sdpa_mask, q, k, v)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()
time_self_attention_sdpa_with_mask()
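For reference, the boolean-mask path this benchmark exercises can be reproduced standalone. A minimal sketch (the shapes below are arbitrary picks for illustration, not the benchmark's actual configuration):

.. code-block:: python

import mlx.core as mx

B, H, L, D = 1, 32, 4096, 64
q = mx.random.uniform(shape=(B, H, 1, D))
k = mx.random.uniform(shape=(B, H, L, D))
v = mx.random.uniform(shape=(B, H, L, D))

# Boolean mask: attend only to the first half of the keys
mask = mx.full((L,), True)
mask[L // 2 :] = False

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
mx.eval(out)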

docs/src/dev/mlx_in_cpp.rst (new file)

@@ -0,0 +1,121 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
This guide is based on the following `example using MLX in C++
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next, make an example program in ``example.cpp``:
.. code-block:: C++
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto x = mx::array({1, 2, 3});
  auto y = mx::array({1, 2, 3});
  std::cout << x + y << std::endl;
  return 0;
}
The next step is to set up a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)

execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX, then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note that ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
:widths: 20 20
:header-rows: 1
* - Variable
- Description
* - MLX_FOUND
- ``True`` if MLX is found
* - MLX_INCLUDE_DIRS
- Include directory
* - MLX_LIBRARIES
- Libraries to link against
* - MLX_CXX_FLAGS
- Additional compiler flags
* - MLX_BUILD_ACCELERATE
- ``True`` if MLX was built with Accelerate
* - MLX_BUILD_METAL
- ``True`` if MLX was built with Metal


@@ -45,6 +45,7 @@ are the CPU and GPU.
usage/numpy
usage/distributed
usage/using_streams
usage/export
.. toctree::
:caption: Examples
@@ -61,6 +62,7 @@ are the CPU and GPU.
python/array
python/data_types
python/devices_and_streams
python/export
python/ops
python/random
python/transforms
@@ -86,3 +88,4 @@ are the CPU and GPU.
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels
dev/mlx_in_cpp


@@ -1,3 +1,5 @@
.. _build_and_install:
Build and Install
=================
@@ -53,7 +55,7 @@ Build Requirements
^^^^^^^^^^^^^^^^^^
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
.. note::


@@ -66,3 +66,4 @@ documentation for more information. Use :func:`issubdtype` to determine if one
Dtype
DtypeCategory
issubdtype
finfo
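``finfo`` is the addition here (see the ``mx.finfo`` commit above). A short sketch of the pattern it enables, where the dtype's minimum stands in for ``-inf`` when masking attention scores, as the SDPA benchmark earlier in this diff does:

.. code-block:: python

import mlx.core as mx

# Smallest representable value for float16, used in place of -inf
# when masking out attention scores
neg_inf = mx.finfo(mx.float16).min
print(neg_inf)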


@@ -0,0 +1,14 @@
.. _export:
Export Functions
================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
export_function
import_function
exporter
export_to_dot


@@ -89,6 +89,7 @@ Operations
isneginf
isposinf
issubdtype
kron
left_shift
less
less_equal
@@ -144,6 +145,8 @@ Operations
sign
sin
sinh
slice
slice_update
softmax
sort
split
@@ -168,6 +171,7 @@ Operations
tri
tril
triu
unflatten
var
view
where


@@ -421,3 +421,77 @@ the most opportunity to optimize the computation graph:
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)
.. _shapeless_compile:
Shapeless Compilation
---------------------
When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.
.. code-block:: python
def fun(x, y):
    return mx.abs(x + y)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.array(1.0)
y = mx.array(-2.0)

# First call compiles the function
print(compiled_fun(x, y))

# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))
Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:
.. code-block:: python
def fun(x):
    return x.reshape(x.shape[0] * x.shape[1], -1)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)

x = mx.random.uniform(shape=(5, 5, 3))
# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)
The second call to ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
.. code-block:: python
def fun(x):
    return x.flatten(0, 1)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)

x = mx.random.uniform(shape=(5, 5, 3))
# Ok
out = compiled_fun(x)


@@ -141,12 +141,13 @@ everything else remaining the same.
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init()
N = mx.distributed.init().size()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads)
lambda x: mx.distributed.all_sum(x) / N,
grads
)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
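The one-line fix above matters because ``mx.distributed.init()`` returns a communication group, not a process count, so dividing the gradients by it was the documentation bug. A minimal sketch of the corrected pattern (assuming a working distributed launch):

.. code-block:: python

import mlx.core as mx

group = mx.distributed.init()
N = group.size()  # number of processes in the group

# Average an array across all processes
x = mx.ones(4)
avg = mx.distributed.all_sum(x) / N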

docs/src/usage/export.rst (new file)

@@ -0,0 +1,288 @@
.. _export_usage:
Exporting Functions
===================
.. currentmodule:: mlx.core
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions, check out the :ref:`API documentation
<export>`.
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
    return x + y

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("add.mlxfn", fun, x, y)
To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:
.. code-block:: python
add_fun = mx.import_function("add.mlxfn")
out, = add_fun(mx.array(1.0), mx.array(2.0))
# Prints: array(3, dtype=float32)
print(out)
out, = add_fun(mx.array(1.0), mx.array(3.0))
# Prints: array(4, dtype=float32)
print(out)
# Raises an exception
add_fun(mx.array(1), mx.array(3.0))
# Raises an exception
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs are different than the shapes and types of the
example inputs we exported the function with.
Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.
The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:
.. code-block:: python
def fun(x, y):
    return x + y

x = mx.array(1.0)
y = mx.array(1.0)

# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)

# Same as above
mx.export_function("add.mlxfn", fun, (x, y))

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y)

# Also ok
out, = imported_fun((x, y))
You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.
.. code-block:: python
def fun(x, y):
    return x + y

# One argument to fun is positional, the other is a kwarg
mx.export_function("add.mlxfn", fun, x, y=y)

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y=y)

# Also ok
out, = imported_fun((x,), {"y": y})

# Raises since the keyword argument is missing
out, = imported_fun(x, y)

# Raises since the keyword argument has the wrong key
out, = imported_fun(x, z=y)
Exporting Modules
-----------------
An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:
.. code-block:: python
model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x):
    return model(x)

mx.export_function("model.mlxfn", call, mx.zeros(4))
In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.
.. note::
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example were missing ``mx.eval(model.parameters())``, the
exported function would also include the random initialization of the
:obj:`mlx.nn.Module` parameters.
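Restated as a compact sketch, the safe pattern is to force evaluation of all
enclosed arrays immediately before exporting:
.. code-block:: python
model = nn.Linear(4, 4)
# Without this eval, the random-initialization graph of the parameters
# would be exported along with the function
mx.eval(model.parameters())
mx.export_function("model.mlxfn", lambda x: model(x), mx.zeros(4))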
If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:
.. code-block:: python
model = nn.Linear(4, 4)
mx.eval(model.parameters())
def call(x, **params):
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
Shapeless Exports
-----------------
Just like :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:
.. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.
Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.
Exporting Multiple Traces
-------------------------
In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).
The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:
.. code-block:: python
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
exporter(mx.array(1.0))
exporter(mx.array(1.0), y=mx.array(0.0))
imported_function = mx.import_function("fun.mlxfn")
# Call the function with y=None
out, = imported_function(mx.array(1.0))
print(out)
# Call the function with y specified
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
print(out)
In the above example, the function's constant data (i.e. ``constant``) is
only saved once.
Transformations with Imported Functions
---------------------------------------
Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
on imported functions just like regular Python functions:
.. code-block:: python
def fun(x):
return mx.sin(x)
x = mx.array(0.0)
mx.export_function("sine.mlxfn", fun, x)
imported_fun = mx.import_function("sine.mlxfn")
# Take the derivative of the imported function
dfdx = mx.grad(lambda x: imported_fun(x)[0])
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
compiled_fun = mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
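As a sketch of the same idea with :func:`vmap`, again using the imported
``sine.mlxfn`` function from above:
.. code-block:: python
vmapped_fun = mx.vmap(lambda x: imported_fun(x)[0])
# Prints: array([0, 0.841471], dtype=float32)
print(vmapped_fun(mx.array([0.0, 1.0])))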
Importing Functions in C++
--------------------------
Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
set up a simple C++ project that uses MLX as a library.
Next, export a simple function from Python:
.. code-block:: python
def fun(x, y):
return mx.exp(x + y)
x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("fun.mlxfn", fun, x, y)
Import and run the function in C++ with only a few lines of code:
.. code-block:: c++
auto fun = mx::import_function("fun.mlxfn");
auto inputs = {mx::array(1.0), mx::array(1.0)};
auto outputs = fun(inputs);
// Prints: array(7.38906, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.
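As a sketch of the keyword-argument convention (assuming a function that was
exported with a keyword argument ``y``; the file name is illustrative):
.. code-block:: c++
auto fun = mx::import_function("fun_with_kwarg.mlxfn");
std::vector<mx::array> args = {mx::array(1.0)};
std::map<std::string, mx::array> kwargs = {{"y", mx::array(1.0)}};
auto outputs = fun(args, kwargs);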
More Examples
-------------
Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:
* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_


@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Comment out the following two commands if the MLX C++ library is installed
# and set(MLX_ROOT "/path/to/mlx") directly instead.
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)


@@ -0,0 +1,26 @@
## Build and Run
Install MLX with Python:
```bash
pip install "mlx>=0.22"
```
Build the C++ example:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
Run the C++ example:
```
./build/example
```
which should output:
```
array([2, 4, 6], dtype=int32)
```


@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}


@@ -4,19 +4,19 @@
#include "mlx/mlx.h"
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
if (!distributed::is_available()) {
if (!mx::distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = distributed::init();
auto global_group = mx::distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_sum(x, global_group);
mx::array x = mx::ones({10});
mx::array out = mx::distributed::all_sum(x, global_group);
std::cout << out << std::endl;
}


@@ -10,7 +10,7 @@
/**
* An example of linear regression with MLX.
*/
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.01;
// True parameters
auto w_star = random::normal({num_features});
auto w_star = mx::random::normal({num_features});
// The input examples (design matrix)
auto X = random::normal({num_examples, num_features});
auto X = mx::random::normal({num_examples, num_features});
// Noisy labels
auto eps = 1e-2 * random::normal({num_examples});
auto y = matmul(X, w_star) + eps;
auto eps = 1e-2 * mx::random::normal({num_examples});
auto y = mx::matmul(X, w_star) + eps;
// Initialize random parameters
array w = 1e-2 * random::normal({num_features});
mx::array w = 1e-2 * mx::random::normal({num_features});
auto loss_fn = [&](array w) {
auto yhat = matmul(X, w);
return (0.5f / num_examples) * sum(square(yhat - y));
auto loss_fn = [&](mx::array w) {
auto yhat = mx::matmul(X, w);
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
};
auto grad_fn = grad(loss_fn);
auto grad_fn = mx::grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
auto grads = grad_fn(w);
w = w - learning_rate * grads;
mx::eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
<< ", Throughput " << throughput << " (it/s)." << std::endl;


@@ -10,7 +10,7 @@
/**
* An example of logistic regression with MLX.
*/
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.1;
// True parameters
auto w_star = random::normal({num_features});
auto w_star = mx::random::normal({num_features});
// The input examples
auto X = random::normal({num_examples, num_features});
auto X = mx::random::normal({num_examples, num_features});
// Labels
auto y = matmul(X, w_star) > 0;
auto y = mx::matmul(X, w_star) > 0;
// Initialize random parameters
array w = 1e-2 * random::normal({num_features});
mx::array w = 1e-2 * mx::random::normal({num_features});
auto loss_fn = [&](array w) {
auto logits = matmul(X, w);
auto loss_fn = [&](mx::array w) {
auto logits = mx::matmul(X, w);
auto scale = (1.0f / num_examples);
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
};
auto grad_fn = grad(loss_fn);
auto grad_fn = mx::grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
auto grads = grad_fn(w);
w = w - learning_rate * grads;
mx::eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
<< throughput << " (it/s)." << std::endl;


@@ -5,27 +5,27 @@
#include "mlx/mlx.h"
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
metal::start_capture("mlx_trace.gputrace");
mx::metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto s2 = new_stream(mx::Device::gpu);
auto s3 = new_stream(mx::Device::gpu);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
auto x = mx::add(a, a, s2);
auto y = mx::add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << multiply(x, y) << std::endl;
std::cout << mx::multiply(x, y) << std::endl;
metal::stop_capture();
mx::metal::stop_capture();
}


@@ -5,11 +5,11 @@
#include "mlx/mlx.h"
using namespace mlx::core;
namespace mx = mlx::core;
void array_basics() {
// Make a scalar array:
array x(1.0);
mx::array x(1.0);
// Get the value out of it:
auto s = x.item<float>();
@@ -29,31 +29,31 @@ void array_basics() {
// The datatype should be float32:
auto dtype = x.dtype();
assert(dtype == float32);
assert(dtype == mx::float32);
// Specify the dtype when constructing the array:
x = array(1, int32);
assert(x.dtype() == int32);
x = mx::array(1, mx::int32);
assert(x.dtype() == mx::int32);
x.item<int>(); // OK
// x.item<float>(); // Undefined!
// Make a multidimensional array:
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
// mlx is row-major by default so the first row of this array
// is [1.0, 2.0] and the second row is [3.0, 4.0]
// Make an array of shape {2, 2} filled with ones:
auto y = ones({2, 2});
auto y = mx::ones({2, 2});
// Pointwise add x and y:
auto z = add(x, y);
auto z = mx::add(x, y);
// Same thing:
z = x + y;
// mlx is lazy by default. At this point `z` only
// has a shape and a type but no actual data:
assert(z.dtype() == float32);
assert(z.dtype() == mx::float32);
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);
@@ -63,33 +63,33 @@ void array_basics() {
// and inputs. When `eval` is called on an array (or arrays), the array and
// all of its dependencies are recursively evaluated to produce the result.
// Once an array is evaluated, it has data and is detached from its inputs.
eval(z);
mx::eval(z);
// Of course the array can still be an input to other operations. You can even
// call eval on the array again, this will just be a no-op:
eval(z); // no-op
// Of course the array can still be an input to other operations. You can
// even call eval on the array again, this will just be a no-op:
mx::eval(z); // no-op
// Some functions or methods on arrays implicitly evaluate them. For example
// accessing a value in an array or printing the array implicitly evaluate it:
z = ones({1});
z = mx::ones({1});
z.item<float>(); // implicit evaluation
z = ones({2, 2});
z = mx::ones({2, 2});
std::cout << z << std::endl; // implicit evaluation
}
void automatic_differentiation() {
auto fn = [](array x) { return square(x); };
auto fn = [](mx::array x) { return mx::square(x); };
// Computing the derivative function of a function
auto grad_fn = grad(fn);
auto grad_fn = mx::grad(fn);
// Call grad_fn on the input to get the derivative
auto x = array(1.5);
auto x = mx::array(1.5);
auto dfdx = grad_fn(x);
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto d2fdx2 = grad(grad(fn))(x);
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
// d2fdx2 is 2
}


@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.27)
project(import_mlx LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)
add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)

examples/export/README.md

@@ -0,0 +1,49 @@
## Setup
Install MLX:
```bash
pip install "mlx>=0.22"
```
Build the C++ examples:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
## Run
### Eval MLP
Run the Python script to export the eval function:
```bash
python eval_mlp.py
```
Then run the C++ program to import and run the function:
```
./build/eval_mlp
```
The Python and C++ programs should output the same result.
### Train MLP
Run the Python script to export the model initialization and training
functions:
```bash
python train_mlp.py
```
Then run the C++ program to import and run the functions:
```
./build/train_mlp
```
The Python and C++ programs should output the same results.


@@ -0,0 +1,25 @@
// Copyright © 2024 Apple Inc.
#include <mlx/mlx.h>
#include <iostream>
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
// Make the input
mx::random::seed(42);
auto example_x = mx::random::uniform({batch_size, input_dim});
// Import the function
auto forward = mx::import_function("eval_mlp.mlxfn");
// Call the imported function
auto out = forward({example_x})[0];
std::cout << out << std::endl;
return 0;
}


@@ -0,0 +1,52 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class MLP(nn.Module):
"""A simple MLP."""
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = nn.relu(l(x))
return self.layers[-1](x)
if __name__ == "__main__":
batch_size = 8
input_dim = 32
output_dim = 10
# Load the model
mx.random.seed(0) # Seed for params
model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
mx.eval(model)
# Note: the model parameters are saved in the exported file
def forward(x):
return model(x)
mx.random.seed(42) # Seed for input
example_x = mx.random.uniform(shape=(batch_size, input_dim))
mx.export_function("eval_mlp.mlxfn", forward, example_x)
# Import in Python
imported_forward = mx.import_function("eval_mlp.mlxfn")
expected = forward(example_x)
(out,) = imported_forward(example_x)
assert mx.allclose(expected, out)
print(out)


@@ -0,0 +1,35 @@
// Copyright © 2024 Apple Inc.
#include <mlx/mlx.h>
#include <iostream>
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
int output_dim = 10;
auto state = mx::import_function("init_mlp.mlxfn")({});
// Make the input
mx::random::seed(42);
auto example_X = mx::random::normal({batch_size, input_dim});
auto example_y = mx::random::randint(0, output_dim, {batch_size});
// Import the function
auto step = mx::import_function("train_mlp.mlxfn");
// Call the imported function
for (int it = 0; it < 100; ++it) {
state.insert(state.end(), {example_X, example_y});
state = step(state);
eval(state);
auto loss = state.back();
state.pop_back();
if (it % 10 == 0) {
std::cout << "Loss " << loss.item<float>() << std::endl;
}
}
return 0;
}


@@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils
class MLP(nn.Module):
"""A simple MLP."""
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = nn.relu(l(x))
return self.layers[-1](x)
if __name__ == "__main__":
batch_size = 8
input_dim = 32
output_dim = 10
def init():
# Seed for the parameter initialization
mx.random.seed(0)
model = MLP(
num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
)
optimizer = optim.SGD(learning_rate=1e-1)
optimizer.init(model.parameters())
state = [model.parameters(), optimizer.state]
tree_structure, state = zip(*mlx.utils.tree_flatten(state))
return model, optimizer, tree_structure, state
# Export the model parameter initialization
model, optimizer, tree_structure, state = init()
mx.eval(state)
mx.export_function("init_mlp.mlxfn", lambda: init()[-1])
def loss_fn(params, X, y):
model.update(params)
return nn.losses.cross_entropy(model(X), y, reduction="mean")
def step(*inputs):
*state, X, y = inputs
params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
optimizer.state = opt_state
loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
params = optimizer.apply_gradients(grads, params)
_, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
return *state, loss
# Make some random data
mx.random.seed(42)
example_X = mx.random.normal(shape=(batch_size, input_dim))
example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))
mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)
# Export one step of SGD
imported_step = mx.import_function("train_mlp.mlxfn")
for it in range(100):
*state, loss = imported_step(*state, example_X, example_y)
if it % 10 == 0:
print(f"Loss {loss.item():.6}")


@@ -18,8 +18,7 @@ find_package(
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------


@@ -19,7 +19,7 @@
#include "mlx/backend/metal/utils.h"
#endif
namespace mlx::core {
namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
@@ -32,24 +32,24 @@ namespace mlx::core {
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
mx::array axpby(
const mx::array& x, // Input mx::array x
const mx::array& y, // Input mx::array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
: promote_types(promoted_dtype, mx::float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
auto x_casted = mx::astype(x, out_dtype, s);
auto y_casted = mx::astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -57,12 +57,12 @@ array axpby(
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
return mx::array(
/* const mx::Shape& shape = */ out_shape,
/* mx::Dtype dtype = */ out_dtype,
/* std::shared_ptr<mx::Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
}
///////////////////////////////////////////////////////////////////////////////
@@ -71,16 +71,16 @@ array axpby(
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
@@ -94,8 +94,8 @@ void axpby_impl(
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
@@ -105,8 +105,8 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -114,14 +114,14 @@ void Axpby::eval(
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
@@ -136,9 +136,9 @@ void Axpby::eval(
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
@@ -150,10 +150,10 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
copy_inplace(y, out, mx::CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
@@ -175,15 +175,15 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
if (out.dtype() == mx::float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
@@ -198,8 +198,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
eval(inputs, outputs);
}
@@ -213,8 +213,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -225,7 +225,7 @@ void Axpby::eval_gpu(
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
auto& d = mx::metal::device(s.device);
// Prepare to specialize based on contiguity
bool contiguous_kernel =
@@ -235,12 +235,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
@@ -279,7 +279,7 @@ void Axpby::eval_gpu(
if (!contiguous_kernel) {
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_vector_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
}
@@ -302,8 +302,8 @@ void Axpby::eval_gpu(
/** Fail evaluation on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -314,9 +314,9 @@ void Axpby::eval_gpu(
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> Axpby::jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
@@ -328,8 +328,8 @@ std::vector<array> Axpby::jvp(
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
auto scale_arr = mx::array(scale, tangents[0].dtype());
return {mx::multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -339,24 +339,24 @@ std::vector<array> Axpby::jvp(
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
std::vector<mx::array> Axpby::vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<mx::array>&) {
// Reverse mode diff
std::vector<array> vjps;
std::vector<mx::array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = mx::array(scale, cotangents[0].dtype());
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
@@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core
} // namespace my_ext

View File

@@ -5,7 +5,9 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace mx = mlx::core;
namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation
@@ -18,22 +20,22 @@ namespace mlx::core {
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
mx::array axpby(
const mx::array& x, // Input array x
const mx::array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
);
///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////
class Axpby : public Primitive {
class Axpby : public mx::Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta) {};
explicit Axpby(mx::Stream stream, float alpha, float beta)
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
void eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
std::vector<mx::array> vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<mx::array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
@@ -66,8 +70,8 @@ class Axpby : public Primitive {
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
@@ -76,14 +80,16 @@ class Axpby : public Primitive {
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
bool is_equivalent(const mx::Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
void eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs);
};
} // namespace mlx::core
} // namespace my_ext


@@ -8,14 +8,12 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&axpby,
&my_ext::axpby,
"x"_a,
"y"_a,
"alpha"_a,


@@ -1,8 +1,8 @@
[build-system]
requires = [
"setuptools>=42",
"cmake>=3.24",
"cmake>=3.25",
"mlx>=0.18.0",
"nanobind==2.2.0",
"nanobind==2.4.0",
]
build-backend = "setuptools.build_meta"


@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
cmake>=3.25
mlx>=0.21.0
nanobind==2.2.0


@@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -18,6 +19,16 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
if(MSVC)
# Disable some MSVC warnings to speed up compilation.
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
endif()
if(WIN32)
# Export symbols by default to behave like macOS/linux.
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()
if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()


@@ -10,22 +10,8 @@
namespace mlx::core {
namespace {
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
return detail::InTracing::in_tracing();
}
bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
auto cval = static_cast<complex64_t>(val);
init(&cval);
}
@@ -61,14 +47,14 @@ std::vector<array> array::make_arrays(
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
float32)) {
init(data.begin());
}
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
dtype)) {
init(data.begin());
}
@@ -119,7 +105,8 @@ void array::eval() {
}
bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph();
return (array_desc_->is_tracer && detail::in_tracing()) ||
detail::retain_graph();
}
void array::set_data(allocator::Buffer buffer, Deleter d) {
@@ -277,7 +264,19 @@ array::ArrayDesc::~ArrayDesc() {
}
ad.inputs.clear();
for (auto& [_, a] : input_map) {
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
bool is_deletable =
(a.array_desc_.use_count() <= a.siblings().size() + 1);
// An array with siblings is deletable only if all of its siblings
// are deletable
for (auto& s : a.siblings()) {
if (!is_deletable) {
break;
}
int is_input = (input_map.find(s.id()) != input_map.end());
is_deletable &=
s.array_desc_.use_count() <= a.siblings().size() + is_input;
}
if (is_deletable) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
@@ -310,7 +309,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
}
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
auto start = std::vector<int>(arr.ndim(), 0);
auto start = Shape(arr.ndim(), 0);
auto end = arr.shape();
auto shape = arr.shape();
shape.erase(shape.begin());


@@ -17,7 +17,8 @@ namespace mlx::core {
class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
class array {
@@ -34,29 +35,29 @@ class array {
explicit array(const std::complex<float>& val, Dtype dtype = complex64);
template <typename It>
array(
explicit array(
It data,
Shape shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
template <typename T>
array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
/* Special case so empty lists default to float32. */
array(std::initializer_list<float> data);
explicit array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype);
explicit array(std::initializer_list<int> data, Dtype dtype);
template <typename T>
array(
explicit array(
std::initializer_list<T> data,
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
explicit array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
@@ -498,7 +499,7 @@ class array {
template <typename T>
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
init(&val);
}
@@ -516,7 +517,7 @@ array::array(
std::initializer_list<T> data,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
dtype)) {
init(data.begin());
}


@@ -32,6 +32,7 @@ DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(BroadcastAxes)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
@@ -43,6 +44,7 @@ DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(ExpandDims)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
@@ -65,7 +67,6 @@ DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
@@ -76,6 +77,7 @@ DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(Squeeze)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)


@@ -5,13 +5,21 @@ else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
if(MSVC)
set(SHELL_EXT ps1)
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
else()
set(SHELL_EXT sh)
set(SHELL_CMD /bin/bash)
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG}
DEPENDS make_compiled_preamble.sh
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
DEPENDS make_compiled_preamble.${SHELL_EXT}
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
@@ -58,5 +66,6 @@ target_sources(
if(IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
endif()


@@ -28,8 +28,8 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) {
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = BinaryOpType::VectorScalar;
} else if (
a.flags().row_contiguous && b.flags().row_contiguous ||
a.flags().col_contiguous && b.flags().col_contiguous) {
(a.flags().row_contiguous && b.flags().row_contiguous) ||
(a.flags().col_contiguous && b.flags().col_contiguous)) {
bopt = BinaryOpType::VectorVector;
} else {
bopt = BinaryOpType::General;


@@ -42,9 +42,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return move_or_copy(in, out, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
@@ -61,6 +59,14 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
move_or_copy(in, out, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out);
}
void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out);
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
@@ -85,6 +91,16 @@ void Depends::eval(
}
}
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto strides = in.strides();
for (auto ax : axes_) {
strides.insert(strides.begin() + ax, 1);
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -141,9 +157,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, Strides> Reshape::prepare_reshape(
const array& in,
const array& out) {
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
@@ -180,7 +194,7 @@ std::pair<bool, Strides> Reshape::prepare_reshape(
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out) {
@@ -248,6 +262,20 @@ void Split::eval(
}
}
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < axes_.size() && i == axes_[j]) {
j++;
} else {
strides.push_back(in.strides(i));
}
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);


@@ -130,7 +130,7 @@ std::string build_lib_name(
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape) {
const Shape& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;


@@ -11,9 +11,7 @@
namespace mlx::core {
inline bool is_static_cast(const Primitive& p) {
return (
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(
@@ -56,7 +54,7 @@ inline bool is_scalar(const array& x) {
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape);
const Shape& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(


@@ -2,6 +2,7 @@
#include <dlfcn.h>
#include <filesystem>
#include <format>
#include <fstream>
#include <list>
#include <mutex>
@@ -9,6 +10,7 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
@@ -44,11 +46,8 @@ namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name).string();
}
} // namespace detail
// Return a pointer to a compiled function
void* compile(
@@ -68,24 +67,30 @@ void* compile(
std::string source_code = source_builder();
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
// characters. Clip file name with a little extra room and append a 16
// character hash.
// Deal with long kernel names. The maximum filename length on macOS is 255
// characters, and on Windows the maximum length of the whole path is 260.
// Clip the file name with a little extra room and append a 16 character hash.
#ifdef _WIN32
constexpr int max_file_name_length = 140;
#else
constexpr int max_file_name_length = 245;
#endif
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
auto file_id = std::hash<std::string>{}(kernel_name);
auto file_id =
std::hash<std::string>{}(kernel_name.substr(max_file_name_length - 16));
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_file_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
auto output_dir = std::filesystem::temp_directory_path();
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
@@ -94,24 +99,21 @@ void* compile(
if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_file_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::string source_file_name = kernel_file_name + ".cpp";
auto source_file_path = (output_dir / source_file_name).string();
std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
<< " with error code " << return_code << "." << std::endl;
throw std::runtime_error(msg.str());
try {
JitCompiler::exec(JitCompiler::build_command(
output_dir, source_file_name, shared_lib_name));
} catch (const std::exception& error) {
throw std::runtime_error(std::format(
"[Compile::eval_cpu] Failed to compile function {0}: {1}",
kernel_name,
error.what()));
}
}
@@ -151,6 +153,11 @@ inline void build_kernel(
NodeNamer namer;
#ifdef _MSC_VER
// Export the symbol
os << "__declspec(dllexport) ";
#endif
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;


@@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
Shape padded_shape = {N, iH + 2 * padding[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N * oH, wH * C};
Shape strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
@@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
auto conv_dtype = out.dtype();
// Pad input
std::vector<int> padded_shape = {
N, iH + 2 * padding[0], iW + 2 * padding[1], C};
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
@@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
const std::vector<int>& wt_dilation,
const bool flip) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = std::vector<int>(
const auto iDim =
Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = Shape(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels
const auto wDim = std::vector<int>(
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
const auto wDim =
Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape(in.shape().size());
Shape padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];


@@ -37,6 +37,7 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BroadcastAxes)
DEFAULT(BlockMaskedMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
@@ -57,6 +58,7 @@ DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(ExpandDims)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
@@ -86,7 +88,6 @@ DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
@@ -101,6 +102,7 @@ DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Squeeze)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)


@@ -0,0 +1,153 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/jit_compiler.h"
#include <algorithm>
#include <sstream>
#include <vector>
#include <format>
namespace mlx::core {
#ifdef _MSC_VER
namespace {
// Split string into array.
std::vector<std::string> str_split(const std::string& str, char delimiter) {
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(str);
while (std::getline(tokenStream, token, delimiter)) {
tokens.push_back(token);
}
return tokens;
}
// Get path information about MSVC.
struct VisualStudioInfo {
VisualStudioInfo() {
#ifdef _M_ARM64
arch = "arm64";
#else
arch = "x64";
#endif
// Get path of Visual Studio.
std::string vs_path = JitCompiler::exec(std::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath",
std::getenv("ProgramFiles(x86)")));
if (vs_path.empty()) {
throw std::runtime_error("Can not find Visual Studio.");
}
// Read the envs from vcvarsall.
std::string envs = JitCompiler::exec(std::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
vs_path,
arch));
for (const std::string& line : str_split(envs, '\n')) {
// Each line is in the format "ENV_NAME=values".
auto pos = line.find_first_of('=');
if (pos == std::string::npos || pos == 0 || pos == line.size() - 1)
continue;
std::string name = line.substr(0, pos);
std::string value = line.substr(pos + 1);
if (name == "LIB") {
libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir") {
cl_exe = std::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
}
}
}
std::string arch;
std::string cl_exe;
std::vector<std::string> libpaths;
};
const VisualStudioInfo& GetVisualStudioInfo() {
static VisualStudioInfo info;
return info;
}
} // namespace
#endif // _MSC_VER
std::string JitCompiler::build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name) {
#ifdef _MSC_VER
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths;
for (const std::string& lib : info.libpaths) {
libpaths += std::format(" /libpath:\"{0}\"", lib);
}
return std::format(
"\""
"cd /D \"{0}\" && "
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
"/link /out:\"{3}\" {4} 2>&1"
"\"",
dir.string(),
info.cl_exe,
source_file_name,
shared_lib_name,
libpaths);
#else
return std::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1",
(dir / source_file_name).string(),
(dir / shared_lib_name).string());
#endif
}
std::string JitCompiler::exec(const std::string& cmd) {
#ifdef _MSC_VER
FILE* pipe = _popen(cmd.c_str(), "r");
#else
FILE* pipe = popen(cmd.c_str(), "r");
#endif
if (!pipe) {
throw std::runtime_error("popen() failed.");
}
char buffer[128];
std::string ret;
while (fgets(buffer, sizeof(buffer), pipe)) {
ret += buffer;
}
// Trim trailing spaces.
ret.erase(
std::find_if(
ret.rbegin(),
ret.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
ret.end());
#ifdef _MSC_VER
int status = _pclose(pipe);
#else
int status = pclose(pipe);
#endif
if (status == -1) {
throw std::runtime_error("pclose() failed.");
}
#ifdef _MSC_VER
int code = status;
#else
int code = WEXITSTATUS(status);
#endif
if (code != 0) {
throw std::runtime_error(std::format(
"Failed to execute command with return code {0}: \"{1}\", "
"the output is: {2}",
code,
cmd,
ret));
}
return ret;
}
} // namespace mlx::core


@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <filesystem>
namespace mlx::core {
class JitCompiler {
public:
// Build a shell command that compiles a source code file to a shared library.
static std::string build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name);
// Run a command and get its output.
static std::string exec(const std::string& cmd);
};
} // namespace mlx::core


@@ -0,0 +1,38 @@
# This script generates a C++ function that provides the CPU
# code for use with kernel generation.
#
# Copyright © 2024 Apple Inc.
$OUTPUT_FILE = $args[0]
$CL = $args[1]
$SRCDIR = $args[2]
# Get command result as array.
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
# Remove empty lines.
# Otherwise there would be too many empty lines, making the result unreadable.
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
# Concatenate to string.
$CONTENT = $CONTENT -join "`n"
# Append extra content.
$CONTENT = @"
$($CONTENT)
using namespace mlx::core;
using namespace mlx::core::detail;
"@
# Convert each char to ASCII code.
# Unlike the unix script, which outputs a string literal directly, the output
# from MSVC is far too large to be embedded as a string (compilation would
# fail), so we store it as a static array instead.
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
$OUTPUT = @"
const char* get_kernel_preamble() {
static char preamble[] = { $CHARCODES };
return preamble;
}
"@
Set-Content -Path $OUTPUT_FILE -Value $OUTPUT


@@ -10,15 +10,16 @@ OUTPUT_FILE=$1
GCC=$2
SRCDIR=$3
CLANG=$4
ARCH=$5
if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
EOM
CC_FLAGS=""
CC_FLAGS="-arch ${ARCH}"
else
CC_FLAGS="-std=c++17"
fi


@@ -19,6 +19,45 @@
namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
int64_t compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
default:
throw std::runtime_error("Invalid indices type.");
}
}
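
A worked example of the offset computation above, with hypothetical values: for a contiguous array with strides {12, 4, 1}, dynamic start indices {2, 3} applied to axes {0, 2} give 2 * 12 + 3 * 1 = 27 elements.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> strides{12, 4, 1}; // contiguous (N, 3, 4) array
  std::vector<int> axes{0, 2};            // axes being sliced dynamically
  std::vector<int32_t> indices{2, 3};     // start indices read at runtime
  int64_t offset = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    offset += indices[i] * strides[axes[i]];
  }
  std::printf("%lld\n", static_cast<long long>(offset)); // prints 27
}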
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -258,6 +297,14 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
}
}
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -417,18 +464,8 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Round::eval(const std::vector<array>& inputs, array& out) {
@@ -499,34 +536,64 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
auto copy_needed = std::any_of(
strides_.begin(), strides_.end(), [](auto i) { return i < 0; });
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Strides ostrides{out.strides().begin(), out.strides().end()};
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
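
The data_end loop above computes how far past the start the shared view reaches, so shared_buffer_slice knows the live span of the input buffer. A standalone run-through with hypothetical numbers (a 4x4 row-contiguous input, slicing rows and columns [1:3]):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> in_shape{4, 4}, in_strides{4, 1};
  std::vector<int> start{1, 1}, out_shape{2, 2}, strides{1, 1};
  size_t data_offset = 1 * 4 + 1 * 1; // what prepare_slice returns for start {1, 1}
  size_t data_end = 1;
  for (int i = 0; i < 2; ++i) {
    if (in_shape[i] > 1) {
      auto end_idx = start[i] + out_shape[i] * strides[i] - 1;
      data_end += end_idx * in_strides[i];
    }
  }
  // offset=5 size=6: the view shares elements 5..10 of the input buffer
  std::printf("offset=%zu size=%zu\n", data_offset, data_end - data_offset);
}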
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
// Copy or move src to dst
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
@@ -554,12 +621,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy
Strides upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
@@ -615,7 +681,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {

View File

@@ -14,10 +14,10 @@ namespace mlx::core {
namespace {
template <typename T, typename IdxT = int32_t>
template <typename T>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = IdxT;
using difference_type = int32_t;
using value_type = T;
using reference = value_type&;
using pointer = value_type*;

View File

@@ -67,7 +67,12 @@ void set_ternary_op_output_data(
}
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break;
}
}
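
The donation added here reuses an input's buffer for the output instead of allocating a new one. A toy illustration of the idea, with std::shared_ptr standing in for the array buffer; maybe_donate in the diff is assumed to succeed only when the input's data is not referenced elsewhere and its layout matches the output:

#include <cstdio>
#include <memory>
#include <vector>

using Buffer = std::shared_ptr<std::vector<float>>;

// Donate in's buffer to out when in is the sole owner (safe to overwrite).
bool maybe_donate(Buffer& in, Buffer& out) {
  if (in.use_count() == 1) {
    out = std::move(in);
    return true;
  }
  return false;
}

int main() {
  Buffer a = std::make_shared<std::vector<float>>(8, 1.0f);
  Buffer out;
  std::printf("%s\n", maybe_donate(a, out) ? "donated" : "allocated"); // donated
}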

View File

@@ -107,7 +107,7 @@ struct ContiguousIterator {
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
pos_ = Shape(shape_.size(), 0);
}
}
@@ -168,4 +168,10 @@ void move_or_copy(
size_t data_size,
size_t offset = 0);
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
} // namespace mlx::core

View File

@@ -34,16 +34,20 @@ BufferCache::~BufferCache() {
clear();
}
void BufferCache::clear() {
int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf)
if (holder->buf) {
holder->buf->release();
n_release++;
}
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
@@ -81,10 +85,11 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
}
}
void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
clear();
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@@ -92,10 +97,12 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
total_bytes_freed += tail_->buf->length();
tail_->buf->release();
tail_->buf = nullptr;
n_release++;
}
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
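
release_cached_buffers now reports how many buffers it released so the allocator can keep num_resources_ accurate. A minimal sketch of the same release-from-the-tail loop, with a std::list of buffer sizes standing in for the cached MTL::Buffer list:

#include <cstdio>
#include <list>

int release_cached(std::list<size_t>& pool, size_t& pool_size, size_t min_bytes) {
  int n_release = 0;
  size_t freed = 0;
  while (!pool.empty() && freed < min_bytes) {
    freed += pool.back(); // tail holds the least recently used buffer
    pool.pop_back();
    n_release++;
  }
  pool_size -= freed;
  return n_release;
}

int main() {
  std::list<size_t> pool{1024, 2048, 4096};
  size_t pool_size = 7168;
  int n = release_cached(pool, pool_size, 5000);
  std::printf("released %d buffers, pool now %zu bytes\n", n, pool_size); // 2, 1024
}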
@@ -144,11 +151,11 @@ MetalAllocator::MetalAllocator()
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
gc_limit_ = std::min(
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
auto max_rec_size =
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
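
The limits above are straightforward arithmetic over device_info(). With hypothetical numbers, memory_size = 16 GiB and max_recommended_working_set_size = 10 GiB:

#include <algorithm>
#include <cstdio>

int main() {
  const double GiB = 1024.0 * 1024.0 * 1024.0;
  double memsize = 16 * GiB, max_rec_size = 10 * GiB;
  double block_limit = std::min(1.5 * max_rec_size, 0.95 * memsize);
  double gc_limit = std::min(0.95 * max_rec_size, block_limit);
  // block_limit = 15.0 GiB, gc_limit = 9.5 GiB
  std::printf("%.1f %.1f\n", block_limit / GiB, gc_limit / GiB);
}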
@@ -186,7 +193,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) {
std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than"
msg << "[metal::malloc] Attempting to allocate " << size
<< " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes.";
throw std::runtime_error(msg.str());
@@ -212,16 +220,26 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache
if (mem_required >= gc_limit_) {
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
num_resources_ -=
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
}
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
if (num_resources_ >= resource_limit_) {
std::ostringstream msg;
msg << "[metal::malloc] Resource limit (" << resource_limit_
<< ") exceeded.";
throw std::runtime_error(msg.str());
}
lk.unlock();
buf = device_->newBuffer(size, res_opt);
lk.lock();
if (buf) {
num_resources_++;
}
}
active_memory_ += buf->length();
@@ -230,7 +248,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
auto pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
num_resources_ -= buffer_cache_.release_cached_buffers(
get_cache_memory() - max_pool_size_);
}
residency_set_.insert(buf);
@@ -241,7 +260,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_);
auto pool = metal::new_scoped_memory_pool();
buffer_cache_.clear();
num_resources_ -= buffer_cache_.clear();
}
void MetalAllocator::free(Buffer buffer) {
@@ -255,6 +274,7 @@ void MetalAllocator::free(Buffer buffer) {
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
num_resources_--;
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();

View File

@@ -23,11 +23,11 @@ class BufferCache {
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
void release_cached_buffers(size_t min_bytes_to_free);
int release_cached_buffers(size_t min_bytes_to_free);
size_t cache_size() {
return pool_size_;
}
void clear();
int clear();
private:
struct BufferHolder {
@@ -94,6 +94,8 @@ class MetalAllocator : public allocator::Allocator {
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
size_t num_resources_{0};
size_t resource_limit_{0};
std::mutex mutex_;
};

View File

@@ -81,13 +81,15 @@ void binary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool large = out.data_size() > UINT32_MAX;
bool large;
auto ndim = shape.size();
int work_per_thread;
if (bopt == BinaryOpType::General) {
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out.size() > INT32_MAX;
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = 1;
}
std::string kernel_name =
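
The threshold change from UINT32_MAX to INT32_MAX tracks the kernels' switch from unsigned to signed 32-bit indexing elsewhere in this diff: any operand that does not fit in int32_t must route to the int64_t ("large") kernel variants. A minimal check of the decision with hypothetical sizes:

#include <climits>
#include <cstdio>

int main() {
  long long a_size = 3'000'000'000LL; // hypothetical element counts
  long long b_size = 10, out_size = 3'000'000'000LL;
  bool large = a_size > INT32_MAX || b_size > INT32_MAX || out_size > INT32_MAX;
  int work_per_thread = large ? 4 : 2; // matches the General branch above
  std::printf("large=%d work_per_thread=%d\n", large, work_per_thread);
}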

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream> //TODO
#include <format>
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -12,8 +11,6 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core {
inline void build_kernel(
@@ -42,7 +39,7 @@ inline void build_kernel(
int cnt = 0;
// Start the kernel
os += fmt::format(
os += std::format(
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
// Add the input arguments
@@ -58,7 +55,7 @@ inline void build_kernel(
if (!is_scalar(x) && !contiguous) {
add_indices = true;
}
os += fmt::format(
os += std::format(
" device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
xname,
@@ -66,13 +63,13 @@ inline void build_kernel(
}
if (add_indices) {
os += fmt::format(
os += std::format(
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
}
// Add the output arguments
for (auto& x : outputs) {
os += fmt::format(
os += std::format(
" device {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
namer.get_name(x),
@@ -80,13 +77,13 @@ inline void build_kernel(
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os += fmt::format(
os += std::format(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
os += std::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
}
if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
os += std::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
}
// The thread index in the whole grid
@@ -99,15 +96,15 @@ inline void build_kernel(
// a third grid dimension
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
os += std::format(" constexpr int N_ = {0};\n", work_per_thread);
os += std::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
os += fmt::format(
os += std::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
} else {
os += fmt::format(
os += std::format(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
@@ -122,16 +119,16 @@ inline void build_kernel(
auto type_str = get_type_string(x.dtype());
std::ostringstream ss;
print_constant(ss, x);
os += fmt::format(
os += std::format(
" auto tmp_{0} = static_cast<{1}>({2});\n",
xname,
get_type_string(x.dtype()),
ss.str());
} else if (is_scalar(x)) {
os += fmt::format(
os += std::format(
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
} else if (contiguous) {
os += fmt::format(
os += std::format(
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
} else {
nc_inputs.push_back(x);
@@ -141,30 +138,30 @@ inline void build_kernel(
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
os += std::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) {
int offset = i * ndim;
os +=
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
std::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
} else if (ndim == 2) {
int offset = i * ndim;
os += fmt::format(
os += std::format(
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type,
offset);
} else if (ndim == 3) {
int offset = i * ndim;
os += fmt::format(
os += std::format(
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
} else if (!dynamic_dims) {
int offset = (i + 1) * ndim;
os += fmt::format(
os += std::format(
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
idx_type,
offset - 1,
offset - 2);
} else {
os += fmt::format(
os += std::format(
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
idx_type,
i);
@@ -176,18 +173,18 @@ inline void build_kernel(
if (dynamic_dims) {
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
os += std::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
}
os += " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" index_{0} += ", xname);
os += std::format(" index_{0} += ", xname);
if (dynamic_dims) {
os +=
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
std::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
} else {
os +=
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
std::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
}
}
os += " zpos /= output_shape[d];\n }\n";
@@ -203,16 +200,16 @@ inline void build_kernel(
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os += fmt::format(
os += std::format(
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
}
// Actually write the computation
for (auto& x : tape) {
os += fmt::format(
os += std::format(
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
if (is_static_cast(x.primitive())) {
os += fmt::format(
os += std::format(
"static_cast<{0}>(tmp_{1});\n",
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
@@ -222,15 +219,15 @@ inline void build_kernel(
os += ss.str();
os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
os += std::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
}
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
os += std::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os += std::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
@@ -238,10 +235,10 @@ inline void build_kernel(
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os += fmt::format(
os += std::format(
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
} else {
os += fmt::format(
os += std::format(
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
}
}

View File

@@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
int implicit_K = wt.size() / conv_params.O;
int implicit_N = conv_params.O;
// Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K};
Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
}
// Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -192,12 +192,12 @@ void conv_1D_gpu(
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(2),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1)},
/* const int wS[NDIM] = */ {wt.shape(1)},
/* const int oS[NDIM] = */ {out.shape(1)},
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
@@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
std::vector<int> padded_shape = {
Shape padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
@@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
array in_padded(padded_shape, in.dtype(), nullptr, {});
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
// Fill with zeros
array zero_arr = array(0, in.dtype());
@@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
copies_w.push_back(in_padded);
MLXConvParams<2> conv_params_updated{
/* const int N = */ in_padded.shape(0),
/* const int C = */ in_padded.shape(3),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int N = */ static_cast<int>(in_padded.shape(0)),
/* const int C = */ static_cast<int>(in_padded.shape(3)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in_padded.shape(1)),
static_cast<int>(in_padded.shape(2))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int kdil[NDIM] = */ {1, 1},
@@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
// Do filter transform
std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
copies_w.push_back(filt_wg);
{
@@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
}
// Do input transform
std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
copies_w.push_back(inp_wg);
{
@@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
}
// Do batched gemm
std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(out_wg_shape, in.dtype(), nullptr, {});
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
copies_w.push_back(out_wg);
{
@@ -723,12 +727,15 @@ void conv_2D_gpu(
std::vector<array>& copies) {
// Make conv params
MLXConvParams<2> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(3),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(3)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
@@ -800,12 +807,21 @@ void conv_3D_gpu(
std::vector<array>& copies) {
// Make conv params
MLXConvParams<3> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(4),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(4)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in.shape(1)),
static_cast<int>(in.shape(2)),
static_cast<int>(in.shape(3))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)),
static_cast<int>(wt.shape(2)),
static_cast<int>(wt.shape(3))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)),
static_cast<int>(out.shape(2)),
static_cast<int>(out.shape(3))},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
/* const int kdil[NDIM] = */

View File

@@ -52,7 +52,9 @@ void copy_gpu_inplace(
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s) {
const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
if (out.size() == 0) {
return;
}
@@ -80,6 +82,7 @@ void copy_gpu_inplace(
} else {
large = out.data_size() > UINT32_MAX;
}
bool dynamic = dynamic_i_offset || dynamic_o_offset;
auto& d = metal::device(s.device);
int work_per_thread = 1;
std::string kernel_name;
@@ -107,9 +110,17 @@ void copy_gpu_inplace(
if (large) {
kernel_name += "large";
}
if (dynamic) {
kernel_name += "_dynamic";
if (ctype != CopyType::GeneralGeneral) {
throw std::runtime_error(
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
}
}
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
: get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -145,6 +156,18 @@ void copy_gpu_inplace(
compute_encoder.set_bytes(ndim, 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
}
if (dynamic) {
if (dynamic_i_offset) {
compute_encoder.set_input_array(*dynamic_i_offset, 6);
} else {
compute_encoder.set_bytes(0ll, 6);
}
if (dynamic_o_offset) {
compute_encoder.set_input_array(*dynamic_o_offset, 7);
} else {
compute_encoder.set_bytes(0ll, 7);
}
}
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
if (thread_group_size != 1024) {
@@ -179,13 +202,13 @@ void copy_gpu_inplace(
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& istride,
int64_t ioffset,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
void fill_gpu(const array& val, array& out, const Stream& s) {

View File

@@ -17,13 +17,15 @@ void copy_gpu_inplace(
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
const Stream& s);
const Stream& s,
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace(
const array& src,
const array& in,
array& out,
CopyType ctype,
const Stream& s);
@@ -31,8 +33,8 @@ void copy_gpu_inplace(
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& istride,
int64_t ioffset,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s);

View File

@@ -651,18 +651,23 @@ device_info() {
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
sysctl(mib, 2, &memsize, &length, NULL, 0);
size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
{"memory_size", memsize},
{"resource_limit", rsrc_limit}};
};
static auto device_info_ = init_device_info();
return device_info_;
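
A standalone macOS sketch of the two lookups added here: total memory via sysctl() with a MIB pair, and the per-process GPU resource limit via sysctlbyname(), with the same 499000 fallback the diff uses when the key is unavailable:

#include <cstdio>
#include <sys/sysctl.h>
#include <sys/types.h>

int main() {
  int mib[] = {CTL_HW, HW_MEMSIZE};
  size_t memsize = 0;
  size_t length = sizeof(memsize);
  sysctl(mib, 2, &memsize, &length, NULL, 0);

  size_t rsrc_limit = 0;
  length = sizeof(rsrc_limit);
  sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
  if (rsrc_limit == 0) {
    rsrc_limit = 499000; // fallback value used by the diff
  }
  std::printf("memory_size=%zu resource_limit=%zu\n", memsize, rsrc_limit);
}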

View File

@@ -3,25 +3,20 @@
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::distributed {
void signal_and_wait(const array& in, const array& out, const Stream& s) {
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
if (in.event().valid()) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
if (e_signal.valid()) {
encode_signal(e_signal);
}
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
encode_wait(e_wait);
}
void AllReduce::eval_gpu(
@@ -38,8 +33,12 @@ void AllReduce::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto e = Event(stream());
e.set_value(1);
signal_and_wait(in.event(), e);
auto task = [in = in,
out = out,
e = std::move(e),
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
@@ -53,11 +52,9 @@ void AllReduce::eval_gpu(
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
out.event().signal();
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out, stream());
}
void AllGather::eval_gpu(
@@ -70,15 +67,19 @@ void AllGather::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto task = [in = in, out = out, group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
out.event().signal();
};
auto e = Event(stream());
e.set_value(1);
signal_and_wait(in.event(), e);
auto task =
[in = in, out = out, e = std::move(e), group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out, stream());
}
void Send::eval_gpu(
@@ -89,27 +90,20 @@ void Send::eval_gpu(
auto& in = inputs[0];
auto& out = outputs[0];
move_or_copy(in, out);
// Schedule an async send on the comm stream
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::send(group, in, dst);
out.event().signal();
distributed::detail::send(group, out, dst);
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
// Encode a signal event for the input but not a wait since we don't need to
// wait on the output.
auto& s = stream();
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
// Encode a signal event for the input
if (in.event().valid()) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
encode_signal(in.event());
}
}
@@ -123,20 +117,18 @@ void Recv::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Schedule an async recv on the comm stream
auto task = [out = out, group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
auto e = Event(stream());
e.set_value(1);
// Encode a wait event since the recv has no input on which to encode a signal.
auto& s = stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
encode_wait(e);
// Schedule an async recv on the comm stream
auto task =
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
} // namespace mlx::core::distributed

View File

@@ -6,6 +6,26 @@
namespace mlx::core {
void encode_wait(Event e) {
auto& d = metal::device(e.stream().device);
d.end_encoding(e.stream().index);
auto command_buffer = d.get_command_buffer(e.stream().index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
command_buffer->addCompletedHandler(
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
}
void encode_signal(Event e) {
auto& d = metal::device(e.stream().device);
d.end_encoding(e.stream().index);
auto command_buffer = d.get_command_buffer(e.stream().index);
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
command_buffer->addCompletedHandler(
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
}
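
The empty addCompletedHandler lambdas above are not dead code: capturing e by move gives the handler ownership of the Event (and its underlying Metal event) until the command buffer completes. A toy illustration of the capture-extends-lifetime idiom:

#include <cstdio>
#include <functional>
#include <memory>

int main() {
  auto event = std::make_shared<int>(42); // stands in for the Event payload
  // Capturing by value gives the handler shared ownership of the resource.
  std::function<void()> completed_handler = [e = event]() {};
  event.reset(); // the caller drops its handle...
  // ...but the payload stays alive until completed_handler is destroyed.
  std::printf("handler still owns the resource\n");
}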
Event::Event(const Stream& stream) : stream_(stream) {
auto dtor = [](void* ptr) {
auto p = metal::new_scoped_memory_pool();

mlx/backend/metal/event.h Normal file
View File

@@ -0,0 +1,10 @@
// Copyright © 2024 Apple Inc.
#pragma once
namespace mlx::core {
void encode_wait(Event e);
void encode_signal(Event e);
} // namespace mlx::core

View File

@@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <format>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
@@ -20,9 +20,9 @@ std::pair<std::string, std::string> make_index_args(
std::ostringstream idx_args;
std::ostringstream idx_arr;
for (int i = 0; i < nidx; ++i) {
idx_args << fmt::format(
idx_args << std::format(
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
idx_arr << fmt::format("idx{0}", i);
idx_arr << std::format("idx{0}", i);
if (i < nidx - 1) {
idx_args << "\n";
idx_arr << ",";
@@ -53,19 +53,19 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
bool large_src = src.size() > UINT32_MAX;
bool large_out = out.size() > UINT32_MAX;
bool large_index = nidx && inputs[1].size() > INT32_MAX;
bool large_src = src.size() > INT32_MAX;
bool large_out = out.size() > INT32_MAX;
bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string kernel_name = fmt::format(
std::string kernel_name = std::format(
"gather{0}{1}_{2}_{3}_{4}",
type_to_name(out),
idx_type_name,
nidx,
idx_ndim,
large ? "int64_t" : "uint");
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
@@ -77,7 +77,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations
kernel_source += fmt::format(
kernel_source += std::format(
gather_kernels,
type_to_name(out) + idx_type_name,
out_type_str,
@@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
idx_ndim,
large ? "int64_t" : "uint");
large ? "int64_t" : "int");
return kernel_source;
});
@@ -234,11 +234,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
auto upd_contig = upd.flags().row_contiguous;
bool large_out = out.size() > UINT32_MAX;
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
bool large_upd = upd.size() > UINT32_MAX;
bool large_out = out.size() > INT32_MAX;
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
bool large_upd = upd.size() > INT32_MAX;
bool large = large_out || large_idx || large_upd;
std::string kernel_name = fmt::format(
std::string kernel_name = std::format(
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
type_to_name(out),
idx_type_name,
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "int64_t" : "uint");
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
@@ -275,11 +275,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
if (reduce_type_ != Scatter::None) {
op_type = fmt::format(fmt::runtime(op_type), out_type_str);
op_type = std::vformat(op_type, std::make_format_args(out_type_str));
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source += fmt::format(
kernel_source += std::format(
scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str,
@@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_arr,
upd_contig,
nwork,
large ? "int64_t" : "uint");
large ? "int64_t" : "int");
return kernel_source;
});

View File

@@ -1,25 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view gemv_masked_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
const device {itype}* mat [[buffer(0)]],
const device {itype}* in_vec [[buffer(1)]],
device {itype}* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -1,32 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_conv_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";
constexpr std::string_view steel_conv_general_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -1,106 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
const device {itype} *A [[buffer(0)]],
const device {itype} *B [[buffer(1)]],
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
device {itype} *D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
{itype},
{outmasktype},
{opmasktype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
{itype},
{otype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {otype}* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device {otype}* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]);
)";

View File

@@ -1,16 +1,11 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
using namespace fmt::literals;
namespace mlx::core {
std::string op_name(const array& arr) {
@@ -26,7 +21,7 @@ MTL::ComputePipelineState* get_arange_kernel(
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::arange()
<< fmt::format(
<< std::format(
arange_kernels,
kernel_name,
get_type_string(out.dtype()));
@@ -52,7 +47,7 @@ MTL::ComputePipelineState* get_unary_kernel(
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
@@ -74,7 +69,7 @@ void append_binary_kernels(
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g1large", "binary_g_nd1"},
{"g2large", "binary_g_nd2"},
{"g3large", "binary_g_nd3"},
}};
@@ -86,11 +81,13 @@ void append_binary_kernels(
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
}
kernel_source += get_template_definition(
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
}
@@ -141,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g1", "ternary_g_nd1"},
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
}};
@@ -150,11 +147,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
kernel_source += get_template_definition(
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
@@ -178,7 +177,7 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
@@ -186,19 +185,23 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source += get_template_definition(
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
kernel_source += get_template_definition(
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(
@@ -210,6 +213,38 @@ MTL::ComputePipelineState* get_copy_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_dynamic_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg_dynamic", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1large_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type);
kernel_source += get_template_definition(
"gg3large_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type);
kernel_source += get_template_definition(
"ggn4large_" + lib_name, "copy_gg_dynamic", in_type, out_type, 4);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -219,7 +254,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
<< std::format(
softmax_kernels,
lib_name,
get_type_string(out.dtype()),
@@ -405,17 +440,17 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
<< fmt::format(
steel_gemm_fused_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
<< get_template_definition(
lib_name,
"gemm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
@@ -440,20 +475,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
steel_gemm_splitk_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
<< get_template_definition(
lib_name,
"gemm_splitk",
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b,
mn_aligned,
k_aligned);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -470,13 +505,12 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
fmt::runtime(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels),
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
<< get_template_definition(
lib_name,
axbpy ? "gemm_splitk_accum_axpby"
: "gemm_splitk_accum",
get_type_string(in.dtype()),
get_type_string(out.dtype()));
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -507,21 +541,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_masked()
<< fmt::format(
steel_gemm_masked_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outmasktype"_a = out_mask_type,
"opmasktype"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
<< get_template_definition(
lib_name,
"block_masked_gemm",
get_type_string(out.dtype()),
out_mask_type,
op_mask_type,
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b,
mn_aligned,
k_aligned);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -550,20 +584,19 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemv_masked()
<< fmt::format(
gemv_masked_kernel,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outm_t"_a = out_mask_type,
"opm_t"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"sm"_a = sm,
"sn"_a = sn,
"tm"_a = tm,
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
<< get_template_definition(
lib_name,
(transpose_mat) ? "gemv_t_masked" : "gemv_masked",
get_type_string(out.dtype()),
out_mask_type,
op_mask_type,
bm,
bn,
sm,
sn,
tm,
tn,
contiguous ? 0 : 1);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -584,17 +617,17 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
steel_conv_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
<< get_template_definition(
lib_name,
"implicit_gemm_conv_2d",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
n_channel_specialization,
small_filter);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -614,15 +647,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
<< fmt::format(
steel_conv_general_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
<< get_template_definition(
lib_name,
"implicit_gemm_conv_2d_general",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);

View File

@@ -1,6 +1,6 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include <format>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
@@ -45,6 +45,12 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& in,
const array& out);
MTL::ComputePipelineState* get_dynamic_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -212,7 +218,7 @@ get_template_definition(std::string name, std::string func, Args... args) {
};
(add_arg(args), ...);
s << ">";
return fmt::format(
return std::format(
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
name,
s.str());
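
For reference, a minimal standalone sketch of this helper, with the argument-separator logic filled in as an assumption since the diff only shows the tail of the function; it demonstrates the instantiation string the steel GEMM sources above now build with C++20 std::format:

#include <format>
#include <iostream>
#include <sstream>
#include <string>

// Minimal re-creation of get_template_definition (separator handling assumed).
template <typename... Args>
std::string get_template_definition(std::string name, std::string func, Args... args) {
  std::ostringstream s;
  s << func << "<";
  bool first = true;
  auto add_arg = [&](const auto& a) {
    if (!first) {
      s << ", ";
    }
    first = false;
    s << a;
  };
  (add_arg(args), ...);
  s << ">";
  return std::format(
      "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
      name,
      s.str());
}

int main() {
  // Prints the same kind of line the split-k accumulation source above emits:
  // template [[host_name("gemm_splitk_accum_float")]] [[kernel]]
  //     decltype(gemm_splitk_accum<float, float>) gemm_splitk_accum<float, float>;
  std::cout << get_template_definition(
      "gemm_splitk_accum_float", "gemm_splitk_accum", "float", "float");
}

The positional {0}/{1} placeholders behave identically in std::format, which is what makes the fmt::format swap mechanical.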

View File

@@ -9,21 +9,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \
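
The uint-to-int switch in these instantiations reads most clearly through the expansion of instantiate_kernel. The macro below is a sketch of the usual MLX helper (its exact definition is not shown in this diff), and the op/type names in the comment are illustrative:

// Sketch: instantiate_kernel pairs a host-visible name with an explicit
// template instantiation of the kernel function.
#define instantiate_kernel(name, func, ...) \
  template [[host_name(name)]] [[kernel]]   \
      decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;

// So a "g1_" line above (op Add, tname float32, illustrative) now expands to
// a signed 32-bit indexed kernel:
//   template [[host_name("g1_Addfloat32")]] [[kernel]]
//       decltype(binary_g_nd1<float, float, Add, int>)
//       binary_g_nd1<float, float, Add, int>;
// while the "g1large_" variant omits the index argument and falls back to the
// kernel's IdxT = int64_t default.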

View File

@@ -7,21 +7,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_float(op) \

View File

@@ -161,3 +161,78 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
idx.y += dst_xstride;
}
}
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint3 index [[thread_position_in_grid]]) {
src += src_offset;
dst += dst_offset;
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
dst_strides,
ndim);
if (N == 1) {
dst[idx.y] = src[idx.x];
return;
}
IdxT src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = src[idx.x];
idx.x += src_xstride;
idx.y += dst_xstride;
}
}
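
A standalone sketch of the addressing in copy_gg_dynamic_nd2 above, with elem_to_loc_2 re-created under the usual innermost-axis-last stride convention: each thread maps its grid position through per-array strides, then applies the offsets delivered in buffers 6 and 7.

#include <cstdint>
#include <iostream>

// Re-creation of elem_to_loc_2 for a 2-d grid position (x = fastest axis).
int64_t elem_to_loc_2(uint32_t x, uint32_t y, const int64_t strides[2]) {
  return int64_t(x) * strides[1] + int64_t(y) * strides[0];
}

int main() {
  int64_t src_strides[2] = {4, 1};  // row-major 3x4 source
  int64_t dst_strides[2] = {1, 3};  // column-major 3x4 destination
  int64_t src_offset = 5;           // offsets arrive at dispatch time
  int64_t dst_offset = 0;
  // Thread (x=2, y=1): src[1*4 + 2*1 + 5] -> dst[1*1 + 2*3 + 0]
  std::cout << "src " << elem_to_loc_2(2, 1, src_strides) + src_offset
            << " -> dst " << elem_to_loc_2(2, 1, dst_strides) + dst_offset
            << "\n";  // src 11 -> dst 7
}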

View File

@@ -4,29 +4,40 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
#define instantiate_copy_same(tname, type) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4) \
instantiate_kernel("gg1_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type, int) \
instantiate_kernel("gg2_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type, int) \
instantiate_kernel("gg3_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type, int) \
instantiate_kernel("ggn2_dynamic_copy" #tname, copy_gg_dynamic, type, type, 2, int) \
instantiate_kernel("gg1large_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type) \
instantiate_kernel("gg2large_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type) \
instantiate_kernel("gg3large_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type) \
instantiate_kernel("ggn4large_dynamic_copy" #tname, copy_gg_dynamic, type, type, 4)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_same(itname ##itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \

View File

@@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched>
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched>
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched>
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched>
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
@@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits>
lhs_indices,
rhs_indices,
y,
out_vec_size,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,
@@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
@@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits>
lhs_indices,
rhs_indices,
y,
out_vec_size,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,
@@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
@@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits>
lhs_indices,
rhs_indices,
y,
out_vec_size,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,
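
A hedged reading of this repeated one-line fix: with M = x_shape[x_batch_ndims] rows of x per batch element, the output batch stride passed to adjust_matrix_offsets must be out_vec_size * M, or consecutive batch entries land on top of each other. A toy illustration with hypothetical sizes:

#include <cstdio>

int main() {
  // Hypothetical sizes: 2 batch elements, M = 3 rows each, out_vec_size = 8.
  int out_vec_size = 8, M = 3;
  for (int batch = 0; batch < 2; ++batch) {
    // The old stride (out_vec_size) would put batch 1 at offset 8, inside the
    // 24 outputs that batch 0's three rows produce; the new stride does not.
    std::printf("batch %d: old offset %d, fixed offset %d\n",
                batch, batch * out_vec_size, batch * out_vec_size * M);
  }
}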

View File

@@ -53,10 +53,10 @@ instantiate_init_min_max(max, Max)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, int64_t, dim) \
@@ -67,7 +67,7 @@ instantiate_init_min_max(max, Max)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, uint, dim, bm, bn) \
itype, otype, op, int, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, int64_t, dim, bm, bn)
@@ -75,7 +75,7 @@ instantiate_init_min_max(max, Max)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, uint, dim, bm, bn) \
itype, otype, op, int, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, int64_t, dim, bm, bn)
@@ -95,7 +95,7 @@ instantiate_init_min_max(max, Max)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, int64_t, dim)
@@ -103,7 +103,7 @@ instantiate_init_min_max(max, Max)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, int64_t, dim)

View File

@@ -4,6 +4,8 @@
using namespace metal;
constant bool has_mask [[function_constant(20)]];
template <typename T, int D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
@@ -15,6 +17,9 @@ template <typename T, int D>
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -39,6 +44,9 @@ template <typename T, int D>
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
}
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
@@ -54,34 +62,39 @@ template <typename T, int D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
if (!has_mask || mask[0]) {
// Read the key
for (int j = 0; j < elem_per_thread; j++) {
k[j] = keys[j];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Compute the i-th score
U score = 0;
for (int j = 0; j < elem_per_thread; j++) {
score += q[j] * k[j];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
// Update the output accumulator
for (int j = 0; j < elem_per_thread; j++) {
o[j] = o[j] * factor + exp_score * values[j];
}
}
// Move the pointers to the next kv
keys += stride;
values += stride;
if (has_mask) {
mask += BN * mask_seq_stride;
}
}
// Each thread has a partial part of the output so we need to combine them.
@@ -126,6 +139,9 @@ template <typename T, int D>
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -155,6 +171,10 @@ template <typename T, int D>
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_seq_stride;
}
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;
@@ -171,34 +191,39 @@ template <typename T, int D>
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
if (!has_mask || mask[0]) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
}
// Move the pointers to the next kv
keys += blocks * stride;
values += blocks * stride;
if (has_mask) {
mask += BN * blocks * mask_seq_stride;
}
}
// Each thread has a partial part of the output so we need to combine them.

View File

@@ -8,17 +8,17 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \

View File

@@ -9,19 +9,19 @@
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)
#define instantiate_unary_float(op) \
#define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_types(op) \
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \

View File

@@ -91,21 +91,7 @@ struct Limits<complex64_t> {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
uint elem,
constant const int* shape,
constant const int64_t* strides,
int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
int64_t elem,
IdxT elem,
constant const int* shape,
constant const int64_t* strides,
int ndim) {
@@ -187,9 +173,12 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
constant const int64_t* c_strides,
int ndim) {
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]);
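
A standalone sketch of the single generic elem_to_loc kept above: it peels one coordinate per dimension, innermost first, mapping a flat element index through arbitrary strides.

#include <cstdint>
#include <iostream>

// Host-side re-creation of the device function retained in the hunk above.
template <typename IdxT = int64_t>
IdxT elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

int main() {
  int shape[2] = {2, 3};
  int64_t strides[2] = {1, 2};  // a transposed (column-major) 2x3 array
  for (int64_t e = 0; e < 6; ++e) {
    std::cout << elem_to_loc(e, shape, strides, 2) << " ";  // 0 2 4 1 3 5
  }
  std::cout << "\n";
}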

View File

@@ -47,7 +47,11 @@ std::function<void()> make_task(array arr, bool signal) {
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
try {
arr.primitive().eval_gpu(arr.inputs(), outputs);
} catch (const std::exception& error) {
abort_with_exception(error);
}
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {

View File

@@ -56,6 +56,14 @@ MTL::ComputePipelineState* get_copy_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_dynamic_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -4,11 +4,13 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
@@ -25,6 +27,78 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1);
}
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
Stream s) {
auto& d = metal::device(s.device);
// Kernel to compute offset here.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.move_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
}
d.add_temporary(offset, s.index);
auto dtype = indices.dtype();
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
auto lib = d.get_library(lib_name, [dtype]() {
return std::format(
R"(
[[kernel]] void compute_dynamic_offset_{0}(
constant const {1}* indices [[buffer(0)]],
device int64_t& offset [[buffer(1)]],
constant const int64_t* strides [[buffer(2)]],
constant const int* axes [[buffer(3)]],
constant const int& n_axes [[buffer(4)]],
uint index [[thread_position_in_grid]]) {{
int64_t acc = 0;
for (int i = 0; i < n_axes; ++i) {{
acc += indices[i] * strides[axes[i]];
}}
offset = acc;
}})",
type_to_name(dtype),
get_type_string(dtype));
});
auto kernel = d.get_kernel(lib_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donate ? offset : indices, 0);
compute_encoder.set_output_array(offset, 1);
compute_encoder.set_vector_bytes(strides, 2);
compute_encoder.set_vector_bytes(axes, 3);
int n_axes = axes.size();
compute_encoder.set_bytes(n_axes, 4);
MTL::Size dims = MTL::Size(1, 1, 1);
compute_encoder.dispatch_threads(dims, dims);
return offset;
}
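
For concreteness, with int32 start indices the std::format call above should emit roughly the following Metal source, assuming type_to_name(int32) yields "int32" and get_type_string(int32) yields "int32_t":

[[kernel]] void compute_dynamic_offset_int32(
    constant const int32_t* indices [[buffer(0)]],
    device int64_t& offset [[buffer(1)]],
    constant const int64_t* strides [[buffer(2)]],
    constant const int* axes [[buffer(3)]],
    constant const int& n_axes [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  int64_t acc = 0;
  for (int i = 0; i < n_axes; ++i) {
    acc += indices[i] * strides[axes[i]];
  }
  offset = acc;
}

A single-thread dispatch then folds the start indices into one int64 offset entirely on the GPU, so dynamic slicing never has to synchronize back to the host.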
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -167,6 +241,10 @@ void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}
@@ -211,6 +289,18 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu(in, out, ctype);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out,
@@ -226,18 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
read_task();
return;
}
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
auto signal_task = [out = out, fut = std::move(fut)]() {
auto e = Event(stream());
e.set_value(1);
encode_wait(e);
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
fut.wait();
out.event().signal();
e.signal();
};
scheduler::enqueue(io_stream(), std::move(signal_task));
auto& d = metal::device(stream().device);
d.end_encoding(stream().index);
auto command_buffer = d.get_command_buffer(stream().index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -305,26 +394,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
@@ -344,6 +414,72 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
slice_gpu(in, out, start_indices_, strides_, stream());
}
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& start = inputs[1];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto s = stream();
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
}
void DynamicSliceUpdate::eval_gpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
auto& start_indices = inputs[2];
if (upd.size() == 0) {
move_or_copy(in, out);
return;
}
// Copy or donate input to output
auto s = stream();
auto& d = metal::device(s.device);
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
auto out_offset =
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
}
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
@@ -359,13 +495,11 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy
@@ -381,6 +515,10 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const Stream& s = */ stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
@@ -424,7 +562,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {

View File

@@ -393,7 +393,7 @@ void row_reduce_small(
auto [in_type, out_type] = remap_reduce_types(in, op_name);
const std::string func_name = "row_reduce_small";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@@ -411,7 +411,7 @@ void row_reduce_small(
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
large ? "size_t" : "int",
n);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -490,7 +490,7 @@ void row_reduce_looped(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
const std::string func_name = "row_reduce_looped";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@@ -508,7 +508,7 @@ void row_reduce_looped(
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
large ? "size_t" : "int",
n);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -574,7 +574,7 @@ void strided_reduce_small(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
const std::string func_name = "col_reduce_small";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@@ -592,7 +592,7 @@ void strided_reduce_small(
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
large ? "size_t" : "int",
n);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
}
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
Shape intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert(
@@ -665,7 +665,7 @@ void strided_reduce_longcolumn(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_longcolumn";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@@ -683,7 +683,7 @@ void strided_reduce_longcolumn(
op_name,
in_type,
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
n);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -706,7 +706,7 @@ void strided_reduce_longcolumn(
// Set the 2nd kernel
func_name = "col_reduce_looped";
kname = func_name;
large = intermediate.size() > UINT32_MAX;
large = intermediate.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@@ -718,7 +718,7 @@ void strided_reduce_longcolumn(
op_name,
intermediate.dtype(),
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
1,
32,
32);
@@ -760,7 +760,7 @@ void strided_reduce_looped(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_looped";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@@ -782,7 +782,7 @@ void strided_reduce_looped(
op_name,
in_type,
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
n,
BM,
BN);
@@ -806,7 +806,7 @@ void strided_reduce_2pass(
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
Shape intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32);
intermediate_shape.insert(
@@ -837,7 +837,7 @@ void strided_reduce_2pass(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_2pass";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@@ -859,7 +859,7 @@ void strided_reduce_2pass(
op_name,
in_type,
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
n,
BM,
BN);
@@ -882,7 +882,7 @@ void strided_reduce_2pass(
// Set the 2nd kernel
func_name = "col_reduce_looped";
kname = func_name;
large = intermediate.size() > UINT32_MAX;
large = intermediate.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@@ -894,7 +894,7 @@ void strided_reduce_2pass(
op_name,
intermediate.dtype(),
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
1,
32,
32);
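
The recurring pattern in these reduce hunks is one dispatch on element count: sizes above INT32_MAX go to the "_large" kernels with 64-bit indexing, and everything else now uses signed 32-bit int rather than uint. A standalone sketch of that selection:

#include <cstdint>
#include <iostream>
#include <string>

// Mirror of the kernel-name selection used throughout reduce.cpp above.
std::string pick_kernel(int64_t size, const std::string& base) {
  bool large = size > INT32_MAX;
  return base + (large ? "_large" : "");
}

int main() {
  std::cout << pick_kernel(1'000, "row_reduce_small") << "\n";          // row_reduce_small
  std::cout << pick_kernel(3'000'000'000, "row_reduce_small") << "\n";  // row_reduce_small_large
}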

View File

@@ -66,7 +66,7 @@ void RoPE::eval_gpu(
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 2;
bool with_freqs = inputs.size() == 3;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
@@ -78,7 +78,7 @@ void RoPE::eval_gpu(
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(offset_, 2);
compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size;
@@ -101,7 +101,7 @@ void RoPE::eval_gpu(
}
if (with_freqs) {
auto& freqs = inputs[1];
auto& freqs = inputs[2];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder.set_bytes(freq_stride, 11);

View File

@@ -1,6 +1,5 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -116,7 +115,8 @@ void sdpa_vector(
const array& k,
const array& v,
array& out,
float scale) {
float scale,
const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -134,9 +134,16 @@ void sdpa_vector(
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1);
bool has_mask = mask.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@@ -149,6 +156,14 @@ void sdpa_vector(
compute_encoder.set_bytes(k_stride, 6);
compute_encoder.set_bytes(v_stride, 7);
compute_encoder.set_bytes(scale, 8);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 9);
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
compute_encoder.set_bytes(seq_stride, 10);
compute_encoder.set_bytes(head_stride, 11);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
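
The optional mask rides on a Metal function constant instead of a second kernel source. A hedged metal-cpp sketch of what a specialized lookup such as d.get_kernel(kname, "mlx", hash_name, func_consts) must do underneath; the kernel name and the device/library variables here are assumptions:

// Assumes `device` (MTL::Device*) and `library` (MTL::Library*) already exist.
bool has_mask = true;
auto* values = MTL::FunctionConstantValues::alloc()->init();
// Index 20 matches `constant bool has_mask [[function_constant(20)]]`.
values->setConstantValue(&has_mask, MTL::DataType::DataTypeBool, NS::UInteger(20));

NS::Error* error = nullptr;
auto* name = NS::String::string("sdpa_vector_float16_t_64", NS::UTF8StringEncoding);
auto* function = library->newFunction(name, values, &error);
auto* pipeline = device->newComputePipelineState(function, &error);
// Each (kernel, has_mask) pair yields a distinct specialized pipeline, which
// is why the code above salts the cache key with "_mask"/"_nomask".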
@@ -161,7 +176,8 @@ void sdpa_vector_2pass(
const array& k,
const array& v,
array& out,
float scale) {
float scale,
const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -198,9 +214,17 @@ void sdpa_vector_2pass(
d.add_temporary(sums, s.index);
d.add_temporary(maxs, s.index);
bool has_mask = mask.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@@ -215,6 +239,14 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(k_stride, 8);
compute_encoder.set_bytes(v_stride, 9);
compute_encoder.set_bytes(scale, 10);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 11);
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
compute_encoder.set_bytes(seq_stride, 12);
compute_encoder.set_bytes(head_stride, 13);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -247,8 +279,6 @@ void sdpa_vector_2pass(
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() == 3);
auto& s = stream();
auto& d = metal::device(s.device);
@@ -296,6 +326,8 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
const auto& q = copy_unless(is_contiguous, q_pre);
// 1, heads, seq_len, head_dim
// mask [1, query_heads, 1, seq_len]
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
@@ -306,15 +338,18 @@ void ScaledDotProductAttention::eval_gpu(
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
auto mask =
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
// We route to the 2 pass fused attention if
// - The device is large and the sequence length long
// - The sequence length is even longer and we have gqa
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_);
sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
} else {
sdpa_vector(s, d, q, k, v, o, scale_);
sdpa_vector(s, d, q, k, v, o, scale_, mask);
}
}

View File

@@ -14,35 +14,18 @@ void slice_gpu(
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
// Calculate out strides, initial offset and if copy needs to be made
// Calculate out strides and initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
auto copy_needed =
std::any_of(strides.begin(), strides.end(), [](auto i) { return i < 0; });
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_gpu_inplace(
/* const array& in = */ in,
/* array& out = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General,
/* const Stream& s = */ s);
} else {
size_t data_end = 1;
for (int i = 0; i < strides.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_end = 1;
for (int i = 0; i < strides.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_size = data_end - data_offset;
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}
size_t data_size = data_end - data_offset;
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}
void concatenate_gpu(
@@ -80,8 +63,8 @@ void pad_gpu(
const array& in,
const array& val,
array& out,
std::vector<int> axes,
std::vector<int> low_pad_size,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);

View File

@@ -23,8 +23,8 @@ void pad_gpu(
const array& in,
const array& val,
array& out,
std::vector<int> axes,
std::vector<int> low_pad_size,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s);
} // namespace mlx::core

View File

@@ -82,9 +82,17 @@ void single_block_sort(
compute_encoder.set_bytes(out_stride_segment_axis, 6);
} else {
compute_encoder.set_bytes(nc_dim, 5);
compute_encoder.set_vector_bytes(nc_shape, 6);
compute_encoder.set_vector_bytes(in_nc_str, 7);
compute_encoder.set_vector_bytes(out_nc_str, 8);
if (nc_shape.empty()) {
int shape = 0;
int64_t stride = 0;
compute_encoder.set_bytes(shape, 6);
compute_encoder.set_bytes(stride, 7);
compute_encoder.set_bytes(stride, 8);
} else {
compute_encoder.set_vector_bytes(nc_shape, 6);
compute_encoder.set_vector_bytes(in_nc_str, 7);
compute_encoder.set_vector_bytes(out_nc_str, 8);
}
}
MTL::Size group_dims = MTL::Size(bn, 1, 1);

View File

@@ -36,15 +36,15 @@ void ternary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
bool large = out.data_size() > UINT_MAX;
bool large;
auto ndim = shape.size();
int work_per_thread;
if (topt == TernaryOpType::General) {
large |=
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
c.data_size() > UINT32_MAX);
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.size() > INT32_MAX;
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
work_per_thread = 1;
}
std::string kernel_name;

View File

@@ -36,9 +36,11 @@ void unary_op_gpu_inplace(
auto [shape, strides] = maybe_collapse();
int ndim = shape.size();
size_t nthreads = contig ? in.data_size() : in.size();
bool large = in.data_size() > UINT32_MAX;
bool large;
if (!contig) {
large |= in.size() > UINT32_MAX;
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
int work_per_thread = !contig && large ? 4 : 1;
std::string kernel_name;

View File

@@ -35,6 +35,7 @@ NO_CPU(AsStrided)
NO_CPU(BitwiseBinary)
NO_CPU(BlockMaskedMM)
NO_CPU(Broadcast)
NO_CPU(BroadcastAxes)
NO_CPU(Ceil)
NO_CPU(Cholesky)
NO_CPU(Concatenate)
@@ -48,6 +49,8 @@ NO_CPU_MULTI(CustomTransforms)
NO_CPU_MULTI(Depends)
NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
NO_CPU(DynamicSlice)
NO_CPU(DynamicSliceUpdate)
NO_CPU(NumberOfElements)
NO_CPU(Remainder)
NO_CPU_MULTI(Eigh)
@@ -55,8 +58,10 @@ NO_CPU(Equal)
NO_CPU(Erf)
NO_CPU(ErfInv)
NO_CPU(Exp)
NO_CPU(ExpandDims)
NO_CPU(Expm1)
NO_CPU(FFT)
NO_CPU(Flatten)
NO_CPU(Floor)
NO_CPU(Full)
NO_CPU(Gather)
@@ -104,6 +109,7 @@ NO_CPU(Softmax)
NO_CPU(Sort)
NO_CPU_MULTI(Split)
NO_CPU(Square)
NO_CPU(Squeeze)
NO_CPU(Sqrt)
NO_CPU(StopGradient)
NO_CPU(Subtract)
@@ -111,6 +117,7 @@ NO_CPU_MULTI(SVD)
NO_CPU(Tan)
NO_CPU(Tanh)
NO_CPU(Transpose)
NO_CPU(Unflatten)
NO_CPU(Inverse)
NO_CPU(View)

View File

@@ -36,6 +36,7 @@ NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BlockMaskedMM)
NO_GPU(Broadcast)
NO_GPU(BroadcastAxes)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
@@ -49,14 +50,18 @@ NO_GPU_MULTI(CustomTransforms)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(NumberOfElements)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(ExpandDims)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Flatten)
NO_GPU(Floor)
NO_GPU(Full)
NO_GPU(Gather)
@@ -104,6 +109,7 @@ NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
NO_GPU(Square)
NO_GPU(Squeeze)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
NO_GPU(Subtract)
@@ -111,6 +117,7 @@ NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Unflatten)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)

View File

@@ -68,24 +68,7 @@ bool is_reduction(const Primitive& p) {
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
is_noop(p);
}
bool allows_shapeless(const Primitive& p) {
return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) ||
is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) ||
typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) ||
typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) ||
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
typeid(p) == typeid(Reshape) || typeid(p) == typeid(Matmul) ||
typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(fast::AffineQuantize) ||
typeid(p) == typeid(fast::LayerNorm) ||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
typeid(p) == typeid(fast::ScaledDotProductAttention);
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);
}
Compiled::Compiled(
@@ -172,9 +155,6 @@ CompileMode& compile_mode() {
return compile_mode_;
}
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Helper like below but only merges the two provided arrays. If the src has
// siblings then these won't be merged to the dst.
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
@@ -303,9 +283,10 @@ CompilerCache& compiler_cache() {
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs) {
const std::vector<array>& inputs,
bool shapeless) {
// Set the global tracing flag.
detail::InTracing in_tracing;
detail::InTracing in_tracing{shapeless};
// Run the function on placeholder inputs
// to get compute graph
@@ -369,12 +350,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
return {tape, parents_map};
}
// Simplify the tape. Note, this function modifies in-place both the tape and
// the parents map to remove orphaned arrays
// Simplify the tape. Note, this function modifies in-place both the tape,
// the parents map to remove orphaned arrays, and potentially the outputs
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& outputs,
std::vector<array>& outputs,
int passes) {
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
@@ -451,6 +432,28 @@ void compile_simplify(
}
tape = std::move(new_tape);
// Remove no-ops
{
std::unordered_map<uintptr_t, array> output_map;
for (auto& o : outputs) {
output_map.insert({o.id(), o});
}
for (auto& arr : tape) {
if (!arr.has_primitive() || !is_noop(arr.primitive())) {
new_tape.push_back(std::move(arr));
continue;
}
merge_one(arr.inputs()[0], arr, parents_map);
if (auto it = output_map.find(arr.id()); it != output_map.end()) {
it->second = arr.inputs()[0];
}
}
tape = std::move(new_tape);
for (auto& o : outputs) {
o = output_map.at(o.id());
}
}
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
for (uint32_t i = 0; i < tape.size(); ++i) {
tape_order.insert({tape[i].id(), i});
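
A standalone sketch of the no-op elimination above: drop flagged nodes from the tape, rewire their consumers to the no-op's single input, and remap any output that was itself a no-op (integer node ids and a noop flag stand in for MLX's arrays and primitives):

#include <iostream>
#include <unordered_map>
#include <vector>

struct Node { int id; int input; bool noop; };

int main() {
  // Tape: 0 -> 1 (no-op) -> 2; the requested output is node 1.
  std::vector<Node> tape = {{0, -1, false}, {1, 0, true}, {2, 1, false}};
  std::vector<int> outputs = {1};
  std::unordered_map<int, int> remap;
  std::vector<Node> new_tape;
  for (auto& n : tape) {
    // Rewire this node's input if it pointed at an eliminated no-op.
    if (auto it = remap.find(n.input); it != remap.end()) n.input = it->second;
    if (n.noop) { remap[n.id] = n.input; continue; }
    new_tape.push_back(n);
  }
  // Outputs that were no-ops are replaced by the no-op's input.
  for (auto& o : outputs)
    if (auto it = remap.find(o); it != remap.end()) o = it->second;
  std::cout << "tape size " << new_tape.size()
            << ", output node " << outputs[0] << "\n";  // tape size 2, output node 0
}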
@@ -460,6 +463,7 @@ void compile_simplify(
for (auto& o : outputs) {
output_set.insert(o.id());
}
// Multi-pass merge only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) {
@@ -748,10 +752,15 @@ std::vector<array> compile_replace(
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };
for (auto& a : tape) {
// Arrays in the tape without primitives are constants
// and can be used directly
if (!a.has_primitive()) {
// Arrays in the tape without primitives are either:
// - inputs, which are already in the map
// - constants, which can be used directly
// - a load primitive which has no inputs and will become a constant
// after the first eval
if (!a.has_primitive() || is_load(a.primitive())) {
trace_to_real.insert({a.id(), a});
} else {
// Find real inputs
@@ -799,24 +808,6 @@ std::vector<array> compile_replace(
return outputs;
}
void compile_validate_shapeless(const std::vector<array>& tape) {
for (auto& t : tape) {
if (!t.has_primitive()) {
continue;
}
auto& p = t.primitive();
if (allows_shapeless(p)) {
continue;
}
std::ostringstream msg;
msg << "[compile] Cannot compile primitive ";
p.print(msg);
msg << " with shapeless enabled.";
throw std::invalid_argument(msg.str());
}
}
bool skip_compile() {
return compile_mode() == CompileMode::disabled ||
!(compile_available_for_device(default_device()));
@@ -856,7 +847,8 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// Set the constants
entry.constants = std::move(constants);
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
std::tie(entry.inputs, entry.outputs) =
compile_trace(fun, inputs, shapeless);
// DFS the graph and get a tape, and a map of array id to (parent,
// position in parent inputs)
@@ -876,10 +868,6 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
if (compile_mode() != CompileMode::no_fuse) {
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
}
if (shapeless) {
compile_validate_shapeless(entry.tape);
}
}
// At this point we must have a tape, now replace the placeholders

View File

@@ -2,7 +2,9 @@
#pragma once
#include "mlx/device.h"
#include <unordered_map>
#include "mlx/array.h"
namespace mlx::core::detail {
@@ -22,4 +24,35 @@ void compile_erase(std::uintptr_t fun_id);
void compile_clear_cache();
bool compile_available_for_device(const Device& device);
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs,
bool shapeless);
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Traverses the graph to build a tape and a map of array ids to their parents
std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& original_inputs);
// Simplify the tape.
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
std::vector<array>& outputs,
int passes);
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs,
bool shapeless);
void compile_validate_shapeless(const std::vector<array>& tape);
} // namespace mlx::core::detail

Some files were not shown because too many files have changed in this diff.