Compare commits

..

14 Commits

Author SHA1 Message Date
Ronan Collobert
87b680766e Gloo backend support 2024-11-13 13:52:37 -08:00
Ronan Collobert
70ffaa50d2 be more relaxed on OpenMPI version 2024-11-13 13:51:37 -08:00
Angelos Katharopoulos
d82699f0f1 Merge branch 'distributed-layers' into socket-distributed-layers 2024-11-05 11:36:16 -08:00
Angelos Katharopoulos
6fc00d2c10 Add rudimentary barrier 2024-11-05 11:34:55 -08:00
Angelos Katharopoulos
44f0de2854 Fix run without distributed 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
29ec3539ed TCP socket distributed 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
e94f0028c3 Change the send message size 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
e5354fcddb Make it work even for donated inputs 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
34dd079a64 Start a sockets based distributed backend 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
16975815e9 Fixes in distributed layers 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos
a8b3da7946 Add distributed layers to nn top-level 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos
060e1c9f92 Add quantized distributed layers 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos
0b04742985 Add the distributed linear layers 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos
c3ccd4919f Add MPI barrier 2024-11-05 11:26:53 -08:00
1911 changed files with 26345 additions and 657962 deletions

View File

@@ -7,9 +7,15 @@ parameters:
nightly_build:
type: boolean
default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
@@ -18,14 +24,13 @@ jobs:
type: boolean
default: false
macos:
xcode: "26.0.0"
resource_class: m4pro.medium
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9
brew install doxygen
python3.9 -m venv env
@@ -33,7 +38,7 @@ jobs:
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
pip install . -v
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
@@ -65,9 +70,9 @@ jobs:
git push -f origin gh-pages
linux_build_and_test:
machine:
image: ubuntu-2204:current
resource_class: large
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
@@ -79,36 +84,33 @@ jobs:
- run:
name: Install dependencies
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install numpy
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
python3 -m unittest discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
@@ -120,64 +122,57 @@ jobs:
parameters:
xcode_version:
type: string
default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
command: |
uv venv --python 3.9
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build example extension
command: |
source .venv/bin/activate
source env/bin/activate
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source .venv/bin/activate
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
@@ -186,7 +181,7 @@ jobs:
- run:
name: Build small binary
command: |
source .venv/bin/activate
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
@@ -198,74 +193,13 @@ jobs:
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
build_release:
parameters:
@@ -274,34 +208,25 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "26.0.0"
default: "15.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m4pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install nanobind==2.2.0
pip install --upgrade setuptools
pip install numpy
pip install twine
@@ -309,38 +234,30 @@ jobs:
- run:
name: Install Python package
command: |
conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
name: Generate package stubs
command: |
conda activate env
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
conda activate env
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
@@ -350,100 +267,52 @@ jobs:
python_version:
type: string
default: "3.9"
build_env:
extra_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.build_env >> pip install ".[dev]" -v
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Build wheel
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
@@ -455,23 +324,21 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
@@ -483,9 +350,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_documentation:
filters:
tags:
@@ -493,25 +359,6 @@ workflows:
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb:
when:
@@ -527,14 +374,9 @@ workflows:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
@@ -545,33 +387,27 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release
build_dev_release:
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
- << pipeline.parameters.weekly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]
extra_env: ["PYPI_RELEASE=1"]

4
.gitignore vendored
View File

@@ -36,7 +36,6 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock
# vim
*.swp
@@ -77,9 +76,6 @@ build/
*.out
*.app
# Debug symbols
*.pdb
# VSCode
.vscode/
.DS_Store

View File

@@ -1,16 +1,15 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
rev: v18.1.8
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.1.0
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.0
rev: 5.13.2
hooks:
- id: isort
args:

View File

@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
@@ -19,17 +19,11 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with

View File

@@ -1,24 +1,6 @@
cmake_minimum_required(VERSION 3.25)
cmake_minimum_required(VERSION 3.24)
if(NOT MLX_VERSION)
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
set(_major ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
set(_minor ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
set(_patch ${CMAKE_MATCH_1})
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
${MLX_VERSION})
endif()
project(
mlx
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})
project(mlx LANGUAGES C CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -34,23 +16,26 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.19.3)
endif()
# --------------------- Processor tests -------------------------
message(
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
set(MLX_BUILD_ARM OFF)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
@@ -66,17 +51,14 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
endif()
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
# ----------------------------- Lib -----------------------------
@@ -93,10 +75,6 @@ if(MLX_BUILD_METAL)
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
@@ -111,26 +89,25 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
if(${MACOS_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif()
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
# Get the metal version
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
@@ -138,64 +115,20 @@ elseif(MLX_BUILD_METAL)
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
set(MLX_BUILD_GGUF OFF)
# There is no prebuilt OpenBLAS distribution for MSVC.
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
endif()
# Windows implementation of dlfcn.h APIs.
FetchContent_Declare(
dlfcn-win32
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.1
EXCLUDE_FROM_ALL)
block()
set(BUILD_SHARED_LIBS OFF)
FetchContent_MakeAvailable(dlfcn-win32)
endblock()
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
target_link_libraries(mlx PRIVATE dl)
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(MLX_USE_ACCELERATE)
add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
# Download and build OpenBLAS from source code.
FetchContent_Declare(
openblas
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
GIT_TAG v0.3.28
EXCLUDE_FROM_ALL)
set(BUILD_STATIC_LIBS ON) # link statically
set(NOFORTRAN ON) # msvc has no fortran compiler
FetchContent_MakeAvailable(openblas)
target_link_libraries(mlx PRIVATE openblas)
target_include_directories(
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
else()
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
@@ -213,7 +146,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
@@ -226,19 +159,30 @@ if(MLX_BUILD_CPU)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "${MPIEXEC_EXECUTABLE} --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*" OR ${MPI_VERSION} MATCHES ".*OpenRTE.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
target_link_libraries(mlx PRIVATE ${MPI_CXX_LIBRARIES})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -246,19 +190,12 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)
@@ -270,7 +207,8 @@ if(MLX_BUILD_PYTHON_BINDINGS)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE nanobind_ROOT)
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

View File

@@ -17,11 +17,11 @@ possible.
You can also run the formatters manually as follows:
```shell
```
clang-format -i file.cpp
```
```shell
```
black file.py
```

View File

@@ -1,6 +1,4 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -68,23 +68,18 @@ in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
```bash
**With `pip`**:
```
pip install mlx
```
To install the CUDA backend on Linux, run:
**With `conda`**:
```bash
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
conda install -c conda-forge mlx
```
Checkout the

View File

@@ -5,35 +5,35 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_value_and_grad() {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
auto x = ones({200, 1000});
eval(x);
auto fn = [](array x) {
for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x));
x = log(exp(x));
}
return mx::sum(x);
return sum(x);
};
auto grad_fn = mx::grad(fn);
auto grad_fn = grad(fn);
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
return std::vector<mx::array>{value, dfdx};
return std::vector<array>{value, dfdx};
};
TIME(independent_value_and_grad);
auto value_and_grad_fn = mx::value_and_grad(fn);
auto value_and_grad_fn = value_and_grad(fn);
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<mx::array>{value, dfdx};
return std::vector<array>{value, dfdx};
};
TIME(combined_value_and_grad);
}
int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_value_and_grad();
}

View File

@@ -4,21 +4,21 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_add_op() {
std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back());
}
set_default_device(mx::Device::cpu);
set_default_device(Device::cpu);
for (auto size : sizes) {
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
std::cout << "Size " << size << std::endl;
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
TIMEM("cpu", add, a, b, Device::cpu);
TIMEM("gpu", add, a, b, Device::gpu);
}
}

View File

@@ -1,111 +1,110 @@
// Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream>
#include <sstream>
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_irregular_binary_ops_1D() {
auto device = mx::default_device();
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", mx::add, a, b, device);
TIMEM("1D strided", add, a, b, device);
}
void time_irregular_binary_ops_2D() {
auto device = mx::default_device();
auto device = default_device();
int size = 2048;
auto a = mx::random::uniform({size, size});
auto b = mx::random::uniform({size, size});
mx::eval(a, b);
TIMEM("2D regular", mx::add, a, b, device);
auto a = random::uniform({size, size});
auto b = random::uniform({size, size});
eval(a, b);
TIMEM("2D regular", add, a, b, device);
b = mx::transpose(b);
mx::eval(b);
TIMEM("2D mx::transpose", mx::add, a, b, device);
b = transpose(b);
eval(b);
TIMEM("2D transpose", add, a, b, device);
b = mx::random::uniform({size});
mx::eval(b);
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({size});
eval(b);
TIMEM("2D broadcast dim 0", add, a, b, device);
b = mx::reshape(b, {size, 1});
mx::eval(b);
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
b = reshape(b, {size, 1});
eval(b);
TIMEM("2D broadcast dim 1", add, a, b, device);
}
void time_irregular_binary_ops_3D() {
auto device = mx::default_device();
auto device = default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = mx::random::uniform({d0, d1, d2});
auto b = mx::random::uniform({d0, d1, d2});
TIMEM("3D regular", mx::add, a, b, device);
auto a = random::uniform({d0, d1, d2});
auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", add, a, b, device);
b = mx::transpose(b, {0, 2, 1});
TIMEM("3D mx::transpose", mx::add, a, b, device);
b = transpose(b, {0, 2, 1});
TIMEM("3D transpose", add, a, b, device);
b = mx::random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", add, a, b, device);
b = mx::random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", add, a, b, device);
b = mx::random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
b = mx::random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
b = mx::random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
}
void time_irregular_binary_ops_4D() {
auto device = mx::default_device();
auto device = default_device();
std::vector<int> shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
auto a = random::uniform(shape);
auto b = random::uniform(shape);
TIMEM("4D regular", mx::add, a, b, device);
TIMEM("4D regular", add, a, b, device);
b = mx::transpose(b, {0, 1, 3, 2});
TIMEM("4D mx::transpose", mx::add, a, b, device);
b = transpose(b, {0, 1, 3, 2});
TIMEM("4D transpose", add, a, b, device);
std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
b = mx::random::uniform(shape);
b = random::uniform(shape);
std::ostringstream msg;
msg << om << i;
TIMEM(msg.str(), mx::add, a, b, device);
TIMEM(msg.str(), add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[k] = a.shape(k);
}
}
@@ -114,83 +113,83 @@ void time_irregular_binary_ops_4D() {
}
void time_irregular_reshape() {
auto device = mx::default_device();
auto device = default_device();
std::vector<int> shape;
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
auto reshape_fn = [&shape, device](const array& a) {
return reshape(a, shape, device);
};
int size = 64;
int d = 2 * size;
auto a = mx::random::uniform({d, d, d});
auto a = random::uniform({d, d, d});
shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);
a = mx::transpose(a);
a = transpose(a);
shape = {8 * size, size, size};
TIMEM("3D mx::transpose", reshape_fn, a);
TIMEM("3D transpose", reshape_fn, a);
a = mx::transpose(a, {1, 2, 0});
a = transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
TIMEM("3D transpose dims 1 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
a = broadcast_to(random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
a = broadcast_to(random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}
void time_irregular_astype_1D() {
auto device = mx::default_device();
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = mx::random::uniform({size});
auto a = random::uniform({size});
a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", mx::astype, a, mx::int32, device);
TIMEM("1D strided", astype, a, int32, device);
}
void time_irregular_astype_2D() {
auto device = mx::default_device();
auto device = default_device();
int size = 2048;
std::vector<int> shape = {size, size};
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);
auto a = random::uniform(shape);
TIMEM("2D regular", astype, a, int32, device);
a = mx::transpose(a);
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
a = transpose(a);
TIMEM("2D transpose", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", astype, a, int32, device);
}
int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
set_default_device(use_gpu ? Device::gpu : Device::cpu);
}
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();

View File

@@ -3,20 +3,20 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
auto full_fp32 = [&]() { return full(shape, 3.3f); };
TIME(full_fp32);
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
TIME(zeros_fp32);
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
auto ones_fp32 = [&]() { return ones(shape, float32); };
TIME(ones_fp32);
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
@@ -24,212 +24,194 @@ void time_type_conversions() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto device = mx::default_device();
auto device = default_device();
auto a = mx::zeros(shape, mx::float32);
mx::eval(a);
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
auto a = zeros(shape, float32);
eval(a);
TIMEM("float32 to int32", astype, a, int32, device);
TIMEM("float32 to uint32", astype, a, uint32, device);
a = mx::zeros(shape, mx::int32);
mx::eval(a);
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
a = zeros(shape, int32);
eval(a);
TIMEM("int32 to float32", astype, a, float32, device);
a = mx::zeros(shape, mx::bool_);
mx::eval(a);
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
a = zeros(shape, bool_);
eval(a);
TIMEM("bool to float32", astype, a, float32, device);
TIMEM("bool to int32", astype, a, int32, device);
TIMEM("bool to uint32", astype, a, uint32, device);
}
void time_random_generation() {
int M = 2000;
int N = 500;
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
auto uniform = [&]() { return random::uniform({M, N}, float32); };
TIME(uniform);
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
auto normal = [&]() { return random::normal({M, N}, float32); };
TIME(normal);
}
void time_unary_ops() {
int M = 2000;
int N = 500;
auto device = mx::default_device();
auto device = default_device();
auto a = mx::random::normal({M, N});
mx::eval(a);
auto a = random::normal({M, N});
eval(a);
TIME(mlx::core::abs, a, device);
TIME(mx::negative, a, device);
TIME(mx::sign, a, device);
TIME(mx::square, a, device);
TIME(negative, a, device);
TIME(sign, a, device);
TIME(square, a, device);
TIME(mlx::core::sqrt, a, device);
TIME(mx::rsqrt, a, device);
TIME(rsqrt, a, device);
TIME(mlx::core::exp, a, device);
a = mx::random::uniform({M, N});
a = random::uniform({M, N});
TIME(mlx::core::log, a, device);
}
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = mx::random::randint(0, 2, {M, N, K});
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(mx::add, a, b, device);
TIME(mx::subtract, a, b, device);
TIME(mx::multiply, a, b, device);
TIME(mx::divide, a, b, device);
TIME(mx::maximum, a, b, device);
TIME(mx::minimum, a, b, device);
TIME(mx::where, condition, a, b, device);
TIME(add, a, b, device);
TIME(subtract, a, b, device);
TIME(multiply, a, b, device);
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = mx::array({true});
b = mx::random::uniform({1});
mx::eval(b);
TIMEM("scalar", mx::add, a, b, device);
TIMEM("vector-scalar", mx::subtract, a, b, device);
TIMEM("scalar-vector", mx::subtract, b, a, device);
TIMEM("scalar", mx::multiply, a, b, device);
TIMEM("vector-scalar", mx::divide, a, b, device);
TIMEM("scalar-vector", mx::divide, b, a, device);
TIMEM("scalar-vector", mx::where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
TIMEM("vector-scalar", subtract, a, b, device);
TIMEM("scalar-vector", subtract, b, a, device);
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
mx::eval(a, b);
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50;
auto a = mx::random::uniform({M, N, O, P});
auto b = mx::random::uniform({M, N, O, P});
auto device = mx::default_device();
mx::eval(a, b);
TIMEM("non-strided", mx::add, a, b, device);
a = mx::transpose(a, {1, 0, 2, 3});
b = mx::transpose(b, {3, 2, 0, 1});
mx::eval(a, b);
TIMEM("strided", mx::add, a, b, device);
auto a = random::uniform({M, N, O, P});
auto b = random::uniform({M, N, O, P});
auto device = default_device();
eval(a, b);
TIMEM("non-strided", add, a, b, device);
a = transpose(a, {1, 0, 2, 3});
b = transpose(b, {3, 2, 0, 1});
eval(a, b);
TIMEM("strided", add, a, b, device);
}
void time_comparisons() {
int M = 1000, N = 100, K = 10;
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::equal, a, b, device);
TIME(mx::greater, a, b, device);
TIME(mx::greater_equal, a, b, device);
TIME(mx::less, a, b, device);
TIME(mx::less_equal, a, b, device);
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(equal, a, b, device);
TIME(greater, a, b, device);
TIME(greater_equal, a, b, device);
TIME(less, a, b, device);
TIME(less_equal, a, b, device);
}
void time_matvec() {
int M = 2000, N = 200;
auto a = mx::random::uniform({M, N});
auto b = mx::random::uniform({N});
auto c = mx::random::uniform({M});
mx::eval(a, b, c);
auto matvec = [&]() { return mx::matmul(a, b); };
auto a = random::uniform({M, N});
auto b = random::uniform({N});
auto c = random::uniform({M});
eval(a, b, c);
auto matvec = [&]() { return matmul(a, b); };
TIME(matvec);
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
auto a = mx::random::uniform({M, K});
auto b = mx::random::uniform({K, N});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::matmul, a, b, device);
auto a = random::uniform({M, K});
auto b = random::uniform({K, N});
auto device = default_device();
eval(a, b);
TIME(matmul, a, b, device);
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
TIME(transpose_matmul);
}
void time_reductions() {
auto a = mx::random::normal({10000, 1000});
mx::eval(a);
auto sum_all = [&a]() { return mx::sum(a, false); };
auto a = random::normal({10000, 1000});
eval(a);
auto sum_all = [&a]() { return sum(a, false); };
TIME(sum_all);
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
TIME(sum_along_0);
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
TIME(sum_along_1);
auto prod_all = [&a]() { return mx::prod(a, false); };
auto prod_all = [&a]() { return prod(a, false); };
TIME(prod_all);
auto all_true = [&a]() { return mx::all(a, false); };
auto all_true = [&a]() { return all(a, false); };
TIME(all_true);
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
auto all_along_0 = [&a]() { return all(a, 0, false); };
TIME(all_along_0);
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
auto all_along_1 = [&a]() { return all(a, 1, false); };
TIME(all_along_1);
auto any_true = [&a]() { return mx::any(a, false); };
auto any_true = [&a]() { return any(a, false); };
TIME(any_true);
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
TIME(argmin_along_0);
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
}
void time_gather_scatter() {
auto a = mx::random::normal({1000, 768});
mx::eval(a);
auto indices = mx::random::randint(0, 1000, {256});
mx::eval(indices);
auto a = random::normal({1000, 768});
eval(a);
auto indices = random::randint(0, 1000, {256});
eval(indices);
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
TIME(embedding_lookup);
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
mx::eval(indices);
indices = random::randint(0, 768 * 1000, {256 * 768});
eval(indices);
auto single_element_lookup = [&a, &indices]() {
return mx::take(a, indices);
};
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
TIME(single_element_lookup);
indices = mx::random::randint(0, 1000, {256});
auto updates = mx::random::normal({256, 1, 768});
mx::eval(indices, updates);
indices = random::randint(0, 1000, {256});
auto updates = random::normal({256, 1, 768});
eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -241,10 +223,10 @@ void time_gather_scatter() {
};
TIME(embedding_add);
a = mx::reshape(a, {-1});
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
updates = mx::random::normal({256 * 768, 1});
mx::eval(a, indices, updates);
a = reshape(a, {-1});
indices = random::randint(0, 768 * 1000, {768 * 256});
updates = random::normal({256 * 768, 1});
eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -258,21 +240,21 @@ void time_gather_scatter() {
}
void time_divmod() {
auto a = mx::random::normal({1000});
auto b = mx::random::normal({1000});
mx::eval({a, b});
auto a = random::normal({1000});
auto b = random::normal({1000});
eval({a, b});
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_creation_ops();
time_type_conversions();
time_unary_ops();

View File

@@ -5,7 +5,6 @@ import os
import time
import torch
import torch.cuda
import torch.mps
@@ -45,10 +44,8 @@ def bench(f, *args):
def sync_if_needed(x):
if x.device == torch.device("mps"):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()
@torch.no_grad()
@@ -102,14 +99,6 @@ def reduction(op, axis, x):
sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
ys = []
@@ -351,11 +340,7 @@ if __name__ == "__main__":
args.axis.pop(0)
torch.set_num_threads(1)
device = "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
device = "cpu" if args.cpu else "mps"
types = args.dtype
if not types:
@@ -475,8 +460,5 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -1,107 +0,0 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch

View File

@@ -1,74 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@@ -1,84 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@@ -1,7 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
@@ -12,71 +10,32 @@ def layer_norm(x, w, b, eps):
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
y = (x - mu) * mx.rsqrt(v + eps)
if w is not None:
y = y * w
if b is not None:
y = y + b
return y
return (x - mu) * mx.rsqrt(v + eps) * w + b
def time_layer_norm(N, dt):
L = 1024
def time_layer_norm():
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(f, x, w, b):
for _ in range(32):
x = f(x, w, b)
return x
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
def layer_norm_grad_loop(g, x, w, b):
def layer_norm_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_grad_loop, g1, x, w, b)
time_fn(layer_norm_grad_loop, g2, x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_grad_x_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(layer_norm_grad_x_loop, g1, x)
time_fn(layer_norm_grad_x_loop, g2, x)
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
if __name__ == "__main__":
for dt in [mx.float32, mx.float16, mx.bfloat16]:
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
print(dt, n)
time_layer_norm(n, dt)
time_layer_norm()

View File

@@ -9,10 +9,7 @@ def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
y = (x * n).astype(ot)
if w is not None:
y = y * w
return y
return (x * n).astype(ot) * w
def time_rms_norm():
@@ -37,27 +34,6 @@ def time_rms_norm():
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(rms_norm_loop, g1, x)
time_fn(rms_norm_loop, g2, x)
time_fn(rms_norm_loop, mx.compile(g1), x)
time_fn(rms_norm_loop, mx.compile(g2), x)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -1,223 +1,62 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
from time_utils import time_fn
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
def bench(f, *args):
for i in range(N_warmup):
f(*args)
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
np_dtype = getattr(np, dtype)
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
kL = k.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if mask is not None:
if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_fused_attn(q, k, v, scale, mask):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose:
q_t = mx.transpose(q, (0, 2, 1, 3))
k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
q_out = q
for i in range(N_iter_func):
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
mx.eval(q_out)
return q_out
def bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
):
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
)
time_mlx_unfused = bench(
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
atol = 1e-5 if dtype == "float32" else 2e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
time_fn(sdpa_fused, q, k, v, scale)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 8),
( 1, 2048, 2048, 64, 32, 8),
( 1, 4096, 4096, 64, 32, 8),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 8),
( 1, 2048, 2048, 80, 32, 8),
( 1, 4096, 4096, 80, 32, 8),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 8),
( 1, 2048, 2048, 128, 32, 8),
( 1, 4096, 4096, 128, 32, 8),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
masks = [None, "bool", "causal"]
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
for mask_in in masks:
time_mlx_fused, time_mlx_unfused = bench_shape(
B,
qsl,
ksl,
head_dim,
n_q_heads,
n_kv_heads,
dtype,
transpose,
mask_in,
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@@ -4,92 +4,46 @@ import math
import mlx.core as mx
from time_utils import time_fn
L = 16384
L = 1024
H = 32
H_k = H // 4
H_k = 32 // 4
D = 128
V = 128
dtype = mx.float16
loops = 10
def upproject(x, w):
if w is None:
return x
else:
return x @ w.T
def attention(q, k, v, mask=None, w=None):
def _sdpa(q, k, v):
def attention(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
_, _, _, V = v.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
if mask is not None:
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, V)
for i in range(loops):
q = _sdpa(q, k, v)
q = upproject(q, w)
return q
return o.reshape(B, Hq, L, D)
def sdpa(q, k, v, mask=None, w=None):
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q = upproject(q, w)
return q
def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mx.eval(q, k, v, w)
time_fn(attention, q, k, v, w=w)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mx.eval(q, k, v, w)
time_fn(sdpa, q, k, v, w=w)
def time_self_attention_sdpa_with_mask():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mask = mx.full((L,), True)
mask[L // 2 :] = False
mx.eval(q, k, v, mask, w)
def sdpa_mask(*args):
return sdpa(*args, mask=mask, w=w)
def attention_mask(*args):
return attention(*args, mask=mask, w=w)
time_fn(attention_mask, q, k, v)
time_fn(sdpa_mask, q, k, v)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()
time_self_attention_sdpa_with_mask()

View File

@@ -51,20 +51,6 @@ def time_maximum():
time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -122,8 +108,6 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_min()
time_max()
time_maximum()
time_exp()
time_negative()

View File

@@ -1,55 +0,0 @@
import time
import mlx.core as mx
rank = mx.distributed.init().rank()
def timeit(fn, a):
# warmup
for _ in range(5):
mx.eval(fn(a))
its = 10
tic = time.perf_counter()
for _ in range(its):
mx.eval(fn(a))
toc = time.perf_counter()
ms = 1000 * (toc - tic) / its
return ms
def all_reduce_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_sum(x)
x = x - 1
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All Reduce: time per iteration {ms:.6f} (ms)")
def all_gather_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_gather(x)[0]
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All gather: time per iteration {ms:.6f} (ms)")
if __name__ == "__main__":
all_reduce_benchmark()
all_gather_benchmark()

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -1,7 +1,5 @@
include(CMakeParseArguments)
# clang format off
#
# ##############################################################################
# Build metal library
#
@@ -11,14 +9,11 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
# files (like headers)
#
# clang format on
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -26,11 +21,7 @@ macro(mlx_build_metallib)
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Prepare metallib build command
add_custom_command(

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = NO
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

View File

@@ -1,4 +0,0 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 6e9fcd3fd9a477c32d79521f0d5d7188
tags: 645f666f9bcd5a90fca523b33c5a78b7

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 746 KiB

View File

@@ -1,7 +0,0 @@
.. _cpp_ops:
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -1,445 +0,0 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note::
Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
The full function signature will be generated using:
* The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``.
* The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``.
* Template parameters passed using ``template``
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
These will be added as function arguments.
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:
.. code-block:: cpp
template <typename T>
[[kernel]] void custom_kernel_myexp_float(
const device float16_t* inp [[buffer(0)]],
device float16_t* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
}
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Using Shape/Strides
-------------------
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source,
ensure_row_contiguous=False,
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Complex Example
-----------------------------
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
``55.7ms -> 6.7ms => 8x speed up``
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
* ``atomic_outputs=True``
Designate all of the kernel outputs as ``atomic`` in the function signature.
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
We can then implement the backwards pass as follows:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T gix_mult = W / 2;
T giy_mult = H / 2;
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:
``676.4ms -> 16.7ms => 40x speed up``

View File

@@ -1,811 +0,0 @@
Custom Extensions in MLX
========================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
Introducing the Example
-----------------------
Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:
.. code-block:: python
import mlx.core as mx
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementation and
function transformations to MLX.
However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:
* The structure of the MLX library.
* Implementing a CPU operation.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
Operations and Primitives
-------------------------
Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
Operations
^^^^^^^^^^^
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
.. code-block:: C++
/**
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Use NumPy-style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
);
The simplest way to implement this is with existing operations:
.. code-block:: C++
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy to use interface that use :class:`Primitive` building blocks.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:
.. code-block:: C++
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** The name of primitive. */
const char* name() const override {
return "Axpby";
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
};
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.
Using the Primitive
^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
auto out_shape = broadcasted_inputs[0].shape();
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
This operation now handles the following:
#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed.
Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing :meth:`Axpby::eval_cpu`.
The method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
template <typename T>
void axpby_impl(
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(y);
encoder.set_output_array(out);
// Launch the CPU kernel
encoder.dispatch([x_ptr = x.data<T>(),
y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
});
}
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.
.. code-block:: C++
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here.
Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.
.. note::
Here are some helpful resources if you are new to Metal:
* A walkthrough of the metal compute pipeline: `Metal Example`_
* Documentation for metal shading language: `Metal Specification`_
* Using metal from C++: `Metal-cpp`_
Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.
.. code-block:: C++
template <typename T>
[[kernel]] void axpby_general(
device const T* x [[buffer(0)]],
device const T* y [[buffer(1)]],
device T* out [[buffer(2)]],
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.
.. code-block:: C++
instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::stream kname;
kname = "axpby_general_" + type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
// Fix the 3D size of each threadgroup (in terms of threads)
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:
.. code-block:: C++
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
// Similarly, if argnums = {1}, the jvp is just the tangent
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
}
}
.. code-block:: C++
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>& /* unused */) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Axpby] vmap not implemented.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
| extensions
| ├── axpby
| │ ├── axpby.cpp
| │ ├── axpby.h
| │ └── axpby.metal
| ├── mlx_sample_extensions
| │ └── __init__.py
| ├── bindings.cpp
| ├── CMakeLists.txt
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
Python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the Python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.
.. code-block:: C++
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
Args:
x (array): Input array.
y (array): Input array.
alpha (float): Scaling factor for ``x``.
beta (float): Scaling factor for ``y``.
Returns:
array: ``alpha * x + beta * y``
)");
}
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` must be imported before importing
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
.. _Building with CMake:
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.
.. code-block:: cmake
# Add library
add_library(mlx_ext)
# Add sources
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
Here is what that looks like in practice:
.. code-block:: cmake
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif()
Finally, we build the nanobind_ bindings
.. code-block:: cmake
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:
.. code-block:: python
from mlx import extension
from setuptools import setup
if __name__ == "__main__":
setup(
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
zip_safe=False,
python_requires=">=3.8",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
This results in the directory structure:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── _ext.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.
Let's look at a simple script and its results:
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c is correct: {mx.all(c == 6.0).item()}")
Output:
.. code-block::
c shape: [3, 4]
c dtype: float32
c is correct: True
Results
^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined.
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
import time
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
M = 4096
N = 4096
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0
mx.eval(x, y)
def bench(f):
# Warm up
for i in range(5):
z = f(x, y, alpha, beta)
mx.eval(z)
# Timed run
s = time.time()
for i in range(100):
z = f(x, y, alpha, beta)
mx.eval(z)
e = time.time()
return 1000 * (e - s) / 100
simple_time = bench(simple_axpby)
custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
modest improvements right away!
This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`.
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/

View File

@@ -1,68 +0,0 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Checkout the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example schema and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@@ -1,121 +0,0 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
This guide is based one the following `example using MLX in C++
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next make an example program in ``example.cpp``:
.. code-block:: C++
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
The next step is to setup a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
:widths: 20 20
:header-rows: 1
* - Variable
- Description
* - MLX_FOUND
- ``True`` if MLX is found
* - MLX_INCLUDE_DIRS
- Include directory
* - MLX_LIBRARIES
- Libraries to link against
* - MLX_CXX_FLAGS
- Additional compiler flags
* - MLX_BUILD_ACCELERATE
- ``True`` if MLX was built with Accelerate
* - MLX_BUILD_METAL
- ``True`` if MLX was built with Metal

View File

@@ -1,77 +0,0 @@
.. _linear_regression:
Linear Regression
-----------------
Let's implement a basic linear regression model as a starting point to
learn MLX. First import the core package and setup some problem metadata:
.. code-block:: python
import mlx.core as mx
num_features = 100
num_examples = 1_000
num_iters = 10_000 # iterations of SGD
lr = 0.01 # learning rate for SGD
We'll generate a synthetic dataset by:
1. Sampling the design matrix ``X``.
2. Sampling a ground truth parameter vector ``w_star``.
3. Compute the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.
.. code-block:: python
# True parameters
w_star = mx.random.normal((num_features,))
# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))
# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps
We will use SGD to find the optimal weights. To start, define the squared loss
and get the gradient function of the loss with respect to the parameters.
.. code-block:: python
def loss_fn(w):
return 0.5 * mx.mean(mx.square(X @ w - y))
grad_fn = mx.grad(loss_fn)
Start the optimization by initializing the parameters ``w`` randomly. Then
repeatedly update the parameters for ``num_iters`` iterations.
.. code-block:: python
w = 1e-2 * mx.random.normal((num_features,))
for _ in range(num_iters):
grad = grad_fn(w)
w = w - lr * grad
mx.eval(w)
Finally, compute the loss of the learned parameters and verify that they are
close to the ground truth parameters.
.. code-block:: python
loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
print(
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
)
# Should print something close to: Loss 0.00005, |w-w*| = 0.00364
Complete `linear regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_
and `logistic regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_
examples are available in the MLX GitHub repo.

View File

@@ -1,382 +0,0 @@
LLM inference
==============
MLX enables efficient inference of large-ish transformers on Apple silicon
without compromising on ease of use. In this example we will create an
inference script for the Llama family of transformer models in which the model
is defined in less than 200 lines of python.
Implementing the model
----------------------
We will use the neural network building blocks defined in the :mod:`mlx.nn`
module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the Llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.
Our implementation uses :class:`mlx.nn.Linear` for all the projections and
:class:`mlx.nn.RoPE` for the positional encoding.
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, traditional=True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
# Note that we return the keys and values to possibly be used as a cache
return self.out_proj(values_hat), (keys, values)
Encoder layer
^^^^^^^^^^^^^
The other component of the Llama model is the encoder layer which uses RMS
normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use
:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.
.. code-block:: python
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
Full model
^^^^^^^^^^
To implement any Llama model we simply have to combine ``LlamaEncoderLayer``
instances with an :class:`mlx.nn.Embedding` to embed the input tokens.
.. code-block:: python
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.layers = [
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
]
self.norm = nn.RMSNorm(dims)
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
x = self.embedding(x)
for l in self.layers:
x, _ = l(x, mask)
x = self.norm(x)
return self.out_proj(x)
Note that in the implementation above we use a simple list to hold the encoder
layers but using ``model.parameters()`` will still consider these layers.
Generation
^^^^^^^^^^^
Our ``Llama`` module can be used for training but not inference as the
``__call__`` method above processes one input, completely ignores the cache and
performs no sampling whatsoever. In the rest of this subsection, we will
implement the inference function as a python generator that processes the
prompt and then autoregressively yields tokens one at a time.
.. code-block:: python
class Llama(nn.Module):
...
def generate(self, x, temp=1.0):
cache = []
# Make an additive causal mask. We will need that to process the prompt.
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
# First we process the prompt x the same way as in __call__ but
# save the caches in cache
x = self.embedding(x)
for l in self.layers:
x, c = l(x, mask=mask)
cache.append(c) # <--- we store the per layer cache in a
# simple python list
x = self.norm(x)
y = self.out_proj(x[:, -1]) # <--- we only care about the last logits
# that generate the next token
y = mx.random.categorical(y * (1/temp))
# y now has size [1]
# Since MLX is lazily evaluated nothing is computed yet.
# Calling y.item() would force the computation to happen at
# this point but we can also choose not to do that and let the
# user choose when to start the computation.
yield y
# Now we parsed the prompt and generated the first token we
# need to feed it back into the model and loop to generate the
# rest.
while True:
# Unsqueezing the last dimension to add a sequence length
# dimension of 1
x = y[:, None]
x = self.embedding(x)
for i in range(len(cache)):
# We are overwriting the arrays in the cache list. When
# the computation will happen, MLX will be discarding the
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1/temp))
yield y
Putting it all together
^^^^^^^^^^^^^^^^^^^^^^^
We now have everything we need to create a Llama model and sample tokens from
it. In the following code, we randomly initialize a small Llama model, process
6 tokens of prompt and generate 10 tokens.
.. code-block:: python
model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)
# Since MLX is lazily evaluated nothing has actually been materialized yet.
# We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
# code above would still run. Let's actually materialize the model.
mx.eval(model.parameters())
prompt = mx.array([[1, 10, 8, 32, 44, 7]]) # <-- Note the double brackets because we
# have a batch dimension even
# though it is 1 in this case
generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]
# Since we haven't evaluated anything, nothing is computed yet. The list
# `generated` contains the arrays that hold the computation graph for the
# full processing of the prompt and the generation of 10 tokens.
#
# We can evaluate them one at a time, or all together. Concatenate them or
# print them. They would all result in very similar runtimes and give exactly
# the same results.
mx.eval(generated)
Converting the weights
----------------------
This section assumes that you have access to the original Llama weights and the
SentencePiece model that comes with them. We will write a small script to
convert the PyTorch weights to MLX compatible ones and write them in a NPZ file
that can be loaded directly by MLX.
.. code-block:: python
import argparse
from itertools import starmap
import numpy as np
import torch
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return key, value.numpy()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument("torch_weights")
parser.add_argument("output_file")
args = parser.parse_args()
state = torch.load(args.torch_weights)
np.savez(
args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
)
Weight loading and benchmarking
-------------------------------
After converting the weights to be compatible to our implementation, all that is
left is to load them from disk and we can finally use the LLM to generate text.
We can load numpy format files using the :func:`mlx.core.load` operation.
To create a parameter dictionary from the key/value representation of NPZ files
we will use the :func:`mlx.utils.tree_unflatten` helper method as follows:
.. code-block:: python
from mlx.utils import tree_unflatten
model.update(tree_unflatten(list(mx.load(weight_file).items())))
:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look
like ``layers.2.attention.query_proj.weight`` and will transform them to
.. code-block:: python
{"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}
which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.
You can download the full example code in `mlx-examples`_. Assuming, the
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
directory we can play around with our inference script as follows (the timings
are representative of an M1 Ultra and the 7B parameter Llama model):
.. code-block:: bash
$ python convert.py weights.pth llama-7B.mlx.npz
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
[INFO] Loading model from disk: 5.247 s
Press enter to start generation
------
, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
------
[INFO] Prompt processing: 0.437 s
[INFO] Full generation: 4.330 s
We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds
of those are spent processing the prompt. This amounts to a little over **39 ms
per token**.
By running with a much bigger prompt we can see that the per token generation
time as well as the prompt processing time remains almost constant.
.. code-block:: bash
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
[INFO] Loading model from disk: 5.247 s
Press enter to start generation
------
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not
------
[INFO] Prompt processing: 0.579 s
[INFO] Full generation: 4.690 s
$ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
[INFO] Loading model from disk: 5.628 s
Press enter to start generation
------
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “
------
[INFO] Prompt processing: 0.633 s
[INFO] Full generation: 21.475 s
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx-examples`_.
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
Roformer: Enhanced transformer with rotary position embedding. arXiv
preprint arXiv:2104.09864.
.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
Advances in Neural Information Processing Systems, 32.
.. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint
arXiv:2002.05202.

View File

@@ -1,134 +0,0 @@
.. _mlp:
Multi-Layer Perceptron
----------------------
In this example we'll learn to use ``mlx.nn`` by implementing a simple
multi-layer perceptron to classify MNIST.
As a first step import the MLX packages we need:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
The model is defined as the ``MLP`` class which inherits from
:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:
1. Define an ``__init__`` where the parameters and/or submodules are setup. See
the :ref:`Module class docs<module_class>` for more information on how
:class:`mlx.nn.Module` registers parameters.
2. Define a ``__call__`` where the computation is implemented.
.. code-block:: python
class MLP(nn.Module):
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = mx.maximum(l(x), 0.0)
return self.layers[-1](x)
We define the loss function which takes the mean of the per-example cross
entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some
commonly used loss functions.
.. code-block:: python
def loss_fn(model, X, y):
return mx.mean(nn.losses.cross_entropy(model(X), y))
We also need a function to compute the accuracy of the model on the validation
set:
.. code-block:: python
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.
.. code-block:: python
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-1
# Load the data
import mnist
train_images, train_labels, test_images, test_labels = map(
mx.array, mnist.mnist()
)
Since we're using SGD, we need an iterator which shuffles and constructs
minibatches of examples in the training set:
.. code-block:: python
def batch_iterate(batch_size, X, y):
perm = mx.array(np.random.permutation(y.size))
for s in range(0, y.size, batch_size):
ids = perm[s : s + batch_size]
yield X[ids], y[ids]
Finally, we put it all together by instantiating the model, the
:class:`mlx.optimizers.SGD` optimizer, and running the training loop:
.. code-block:: python
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())
# Get a function which gives the loss and gradient of the
# loss with respect to the model's trainable parameters
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Instantiate the optimizer
optimizer = optim.SGD(learning_rate=learning_rate)
for e in range(num_epochs):
for X, y in batch_iterate(batch_size, train_images, train_labels):
loss, grads = loss_and_grad_fn(model, X, y)
# Update the optimizer state and model parameters
# in a single call
optimizer.update(model, grads)
# Force a graph evaluation
mx.eval(model.parameters(), optimizer.state)
accuracy = eval_fn(model, test_images, test_labels)
print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")
.. note::
The :func:`mlx.nn.value_and_grad` function is a convenience function to get
the gradient of a loss with respect to the trainable parameters of a model.
This should not be confused with :func:`mlx.core.value_and_grad`.
The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mnist>`_
is available in the MLX GitHub repo.

View File

@@ -1,93 +0,0 @@
MLX
===
MLX is a NumPy-like array framework designed for efficient and flexible machine
learning on Apple silicon, brought to you by Apple machine learning research.
The Python API closely follows NumPy with a few exceptions. MLX also has a
fully featured C++ API which closely follows the Python API.
The main differences between MLX and NumPy are:
- **Composable function transformations**: MLX has composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
GPU, ...)
The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.
.. toctree::
:caption: Install
:maxdepth: 1
install
.. toctree::
:caption: Usage
:maxdepth: 1
usage/quick_start
usage/lazy_evaluation
usage/unified_memory
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
usage/export
.. toctree::
:caption: Examples
:maxdepth: 1
examples/linear_regression
examples/mlp
examples/llama-inference
.. toctree::
:caption: Python API Reference
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/export
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
python/cuda
python/memory_management
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::
:caption: C++ API Reference
:maxdepth: 1
cpp/ops
.. toctree::
:caption: Further Reading
:maxdepth: 1
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels
dev/mlx_in_cpp

View File

@@ -1,345 +0,0 @@
.. _build_and_install:
Build and Install
=================
Python Installation
-------------------
MLX is available on PyPI. All you have to do to use MLX with your own Apple
silicon computer is
.. code-block:: shell
pip install mlx
To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
CUDA
^^^^
MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
Troubleshooting
^^^^^^^^^^^^^^^
*My OS and Python versions are in the required range but pip still does not find
a matching distribution.*
Probably you are using a non-native Python. The output of
.. code-block:: shell
python -c "import platform; print(platform.processor())"
should be ``arm``. If it is ``i386`` (and you have M series machine) then you
are using a non-native Python. Switch your Python to a native Python. A good
way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.
Build from source
-----------------
Build Requirements
^^^^^^^^^^^^^^^^^^
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.
Python API
^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
.. code-block:: shell
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Then simply build and install MLX using pip:
.. code-block:: shell
pip install .
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
python setup.py build_ext --inplace
Run the tests with:
.. code-block:: shell
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your
IDE:
.. code-block:: shell
python setup.py generate_stubs
C++ API
^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start
by cloning MLX from `its GitHub repo
<https://github.com/ml-explore/mlx>`_:
.. code-block:: shell
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
Run tests with:
.. code-block:: shell
make test
Install with:
.. code-block:: shell
make install
Note that the built ``mlx.metallib`` file should be either at the same
directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
* - Option
- Default
* - MLX_BUILD_TESTS
- ON
* - MLX_BUILD_EXAMPLES
- OFF
* - MLX_BUILD_BENCHMARKS
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
.. code-block:: shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
You see the following error when you try to build:
.. code-block:: shell
error: unable to find utility "metal", not a developer tool or in PATH
To fix this, first make sure you have Xcode installed:
.. code-block:: shell
xcode-select --install
Then set the active developer directory:
.. code-block:: shell
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
~~~~~~~~~
.. _build shell:
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
terminal.
Verify the terminal is now running natively the following command:
.. code-block:: shell
$ uname -p
arm
Also check that cmake is using the correct architecture:
.. code-block:: shell
$ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
CMAKE_HOST_SYSTEM_PROCESSOR "arm64"
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cache with ``rm -rf build/`` and try again.

View File

@@ -1,28 +0,0 @@
mlx.core.Device
===============
.. currentmodule:: mlx.core
.. autoclass:: Device
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~Device.__init__
.. rubric:: Attributes
.. autosummary::
~Device.type

View File

@@ -1,28 +0,0 @@
mlx.core.Dtype
==============
.. currentmodule:: mlx.core
.. autoclass:: Dtype
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~Dtype.__init__
.. rubric:: Attributes
.. autosummary::
~Dtype.size

View File

@@ -1,29 +0,0 @@
mlx.core.DtypeCategory
======================
.. currentmodule:: mlx.core
.. autoclass:: DtypeCategory
.. automethod:: __init__
.. rubric:: Attributes
.. autosummary::
~DtypeCategory.complexfloating
~DtypeCategory.floating
~DtypeCategory.inexact
~DtypeCategory.signedinteger
~DtypeCategory.unsignedinteger
~DtypeCategory.integer
~DtypeCategory.number
~DtypeCategory.generic

View File

@@ -1,6 +0,0 @@
mlx.core.abs
============
.. currentmodule:: mlx.core
.. autofunction:: abs

View File

@@ -1,6 +0,0 @@
mlx.core.add
============
.. currentmodule:: mlx.core
.. autofunction:: add

View File

@@ -1,6 +0,0 @@
mlx.core.addmm
==============
.. currentmodule:: mlx.core
.. autofunction:: addmm

View File

@@ -1,6 +0,0 @@
mlx.core.all
============
.. currentmodule:: mlx.core
.. autofunction:: all

View File

@@ -1,6 +0,0 @@
mlx.core.allclose
=================
.. currentmodule:: mlx.core
.. autofunction:: allclose

View File

@@ -1,6 +0,0 @@
mlx.core.any
============
.. currentmodule:: mlx.core
.. autofunction:: any

View File

@@ -1,6 +0,0 @@
mlx.core.arange
===============
.. currentmodule:: mlx.core
.. autofunction:: arange

View File

@@ -1,6 +0,0 @@
mlx.core.arccos
===============
.. currentmodule:: mlx.core
.. autofunction:: arccos

View File

@@ -1,6 +0,0 @@
mlx.core.arccosh
================
.. currentmodule:: mlx.core
.. autofunction:: arccosh

View File

@@ -1,6 +0,0 @@
mlx.core.arcsin
===============
.. currentmodule:: mlx.core
.. autofunction:: arcsin

View File

@@ -1,6 +0,0 @@
mlx.core.arcsinh
================
.. currentmodule:: mlx.core
.. autofunction:: arcsinh

View File

@@ -1,6 +0,0 @@
mlx.core.arctan
===============
.. currentmodule:: mlx.core
.. autofunction:: arctan

View File

@@ -1,6 +0,0 @@
mlx.core.arctan2
================
.. currentmodule:: mlx.core
.. autofunction:: arctan2

View File

@@ -1,6 +0,0 @@
mlx.core.arctanh
================
.. currentmodule:: mlx.core
.. autofunction:: arctanh

View File

@@ -1,6 +0,0 @@
mlx.core.argmax
===============
.. currentmodule:: mlx.core
.. autofunction:: argmax

View File

@@ -1,6 +0,0 @@
mlx.core.argmin
===============
.. currentmodule:: mlx.core
.. autofunction:: argmin

View File

@@ -1,6 +0,0 @@
mlx.core.argpartition
=====================
.. currentmodule:: mlx.core
.. autofunction:: argpartition

View File

@@ -1,6 +0,0 @@
mlx.core.argsort
================
.. currentmodule:: mlx.core
.. autofunction:: argsort

View File

@@ -1,6 +0,0 @@
mlx.core.array.T
================
.. currentmodule:: mlx.core
.. autoproperty:: array.T

View File

@@ -1,6 +0,0 @@
mlx.core.array.abs
==================
.. currentmodule:: mlx.core
.. automethod:: array.abs

View File

@@ -1,6 +0,0 @@
mlx.core.array.all
==================
.. currentmodule:: mlx.core
.. automethod:: array.all

View File

@@ -1,6 +0,0 @@
mlx.core.array.any
==================
.. currentmodule:: mlx.core
.. automethod:: array.any

View File

@@ -1,6 +0,0 @@
mlx.core.array.argmax
=====================
.. currentmodule:: mlx.core
.. automethod:: array.argmax

View File

@@ -1,6 +0,0 @@
mlx.core.array.argmin
=====================
.. currentmodule:: mlx.core
.. automethod:: array.argmin

View File

@@ -1,6 +0,0 @@
mlx.core.array.astype
=====================
.. currentmodule:: mlx.core
.. automethod:: array.astype

View File

@@ -1,6 +0,0 @@
mlx.core.array.at
=================
.. currentmodule:: mlx.core
.. autoproperty:: array.at

View File

@@ -1,6 +0,0 @@
mlx.core.array.conj
===================
.. currentmodule:: mlx.core
.. automethod:: array.conj

View File

@@ -1,6 +0,0 @@
mlx.core.array.cos
==================
.. currentmodule:: mlx.core
.. automethod:: array.cos

View File

@@ -1,6 +0,0 @@
mlx.core.array.cummax
=====================
.. currentmodule:: mlx.core
.. automethod:: array.cummax

View File

@@ -1,6 +0,0 @@
mlx.core.array.cummin
=====================
.. currentmodule:: mlx.core
.. automethod:: array.cummin

View File

@@ -1,6 +0,0 @@
mlx.core.array.cumprod
======================
.. currentmodule:: mlx.core
.. automethod:: array.cumprod

View File

@@ -1,6 +0,0 @@
mlx.core.array.cumsum
=====================
.. currentmodule:: mlx.core
.. automethod:: array.cumsum

View File

@@ -1,6 +0,0 @@
mlx.core.array.diag
===================
.. currentmodule:: mlx.core
.. automethod:: array.diag

View File

@@ -1,6 +0,0 @@
mlx.core.array.diagonal
=======================
.. currentmodule:: mlx.core
.. automethod:: array.diagonal

View File

@@ -1,6 +0,0 @@
mlx.core.array.dtype
====================
.. currentmodule:: mlx.core
.. autoproperty:: array.dtype

View File

@@ -1,6 +0,0 @@
mlx.core.array.exp
==================
.. currentmodule:: mlx.core
.. automethod:: array.exp

View File

@@ -1,6 +0,0 @@
mlx.core.array.flatten
======================
.. currentmodule:: mlx.core
.. automethod:: array.flatten

View File

@@ -1,6 +0,0 @@
mlx.core.array.imag
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.imag

View File

@@ -1,6 +0,0 @@
mlx.core.array.item
===================
.. currentmodule:: mlx.core
.. automethod:: array.item

View File

@@ -1,6 +0,0 @@
mlx.core.array.itemsize
=======================
.. currentmodule:: mlx.core
.. autoproperty:: array.itemsize

View File

@@ -1,6 +0,0 @@
mlx.core.array.log
==================
.. currentmodule:: mlx.core
.. automethod:: array.log

View File

@@ -1,6 +0,0 @@
mlx.core.array.log10
====================
.. currentmodule:: mlx.core
.. automethod:: array.log10

View File

@@ -1,6 +0,0 @@
mlx.core.array.log1p
====================
.. currentmodule:: mlx.core
.. automethod:: array.log1p

View File

@@ -1,6 +0,0 @@
mlx.core.array.log2
===================
.. currentmodule:: mlx.core
.. automethod:: array.log2

View File

@@ -1,6 +0,0 @@
mlx.core.array.logcumsumexp
===========================
.. currentmodule:: mlx.core
.. automethod:: array.logcumsumexp

View File

@@ -1,6 +0,0 @@
mlx.core.array.logsumexp
========================
.. currentmodule:: mlx.core
.. automethod:: array.logsumexp

View File

@@ -1,6 +0,0 @@
mlx.core.array.max
==================
.. currentmodule:: mlx.core
.. automethod:: array.max

View File

@@ -1,6 +0,0 @@
mlx.core.array.mean
===================
.. currentmodule:: mlx.core
.. automethod:: array.mean

View File

@@ -1,6 +0,0 @@
mlx.core.array.min
==================
.. currentmodule:: mlx.core
.. automethod:: array.min

View File

@@ -1,6 +0,0 @@
mlx.core.array.moveaxis
=======================
.. currentmodule:: mlx.core
.. automethod:: array.moveaxis

View File

@@ -1,6 +0,0 @@
mlx.core.array.nbytes
=====================
.. currentmodule:: mlx.core
.. autoproperty:: array.nbytes

View File

@@ -1,6 +0,0 @@
mlx.core.array.ndim
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.ndim

View File

@@ -1,6 +0,0 @@
mlx.core.array.prod
===================
.. currentmodule:: mlx.core
.. automethod:: array.prod

View File

@@ -1,6 +0,0 @@
mlx.core.array.real
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.real

View File

@@ -1,6 +0,0 @@
mlx.core.array.reciprocal
=========================
.. currentmodule:: mlx.core
.. automethod:: array.reciprocal

View File

@@ -1,6 +0,0 @@
mlx.core.array.reshape
======================
.. currentmodule:: mlx.core
.. automethod:: array.reshape

View File

@@ -1,6 +0,0 @@
mlx.core.array.round
====================
.. currentmodule:: mlx.core
.. automethod:: array.round

View File

@@ -1,6 +0,0 @@
mlx.core.array.rsqrt
====================
.. currentmodule:: mlx.core
.. automethod:: array.rsqrt

Some files were not shown because too many files have changed in this diff Show More