Compare commits

..

No commits in common. "main" and "v0.5.1" have entirely different histories.
main ... v0.5.1

677 changed files with 29380 additions and 104248 deletions

View File

@ -13,65 +13,8 @@ parameters:
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
cuda_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
@ -88,67 +31,53 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
python3 setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
- run:
name: Build CPP only
command: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install torch
pip install tensorflow
@ -157,86 +86,35 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build example extension
command: |
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3.11 -m pip install .
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
mkdir -p build && cd build && cmake .. && make -j
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test:
machine:
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Install Python package
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
DEVICE=cpu ./build/tests/tests
build_release:
parameters:
@ -245,31 +123,26 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "16.2.0"
default: "15.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
pip install build
@ -277,21 +150,20 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build -w
- when:
condition: << parameters.build_env >>
@ -304,7 +176,7 @@ jobs:
- store_artifacts:
path: dist/
build_linux_release:
build_linux_test_release:
parameters:
python_version:
type: string
@ -333,70 +205,22 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
python_version:
type: string
default: "3.9"
extra_env:
type: string
default: "DEV_RELEASE=1"
machine:
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Build wheel
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
source env/bin/activate
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install ".[dev]" -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build --wheel
bash python/scripts/repair_cuda.sh
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
@ -411,13 +235,8 @@ workflows:
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- mac_build_and_test
- linux_build_and_test
- cuda_build_and_test
- build_documentation
build_pypi_release:
when:
@ -434,79 +253,9 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb:
when:
matches:
@ -519,13 +268,8 @@ workflows:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
nightly_build:
when:
and:
@ -535,55 +279,8 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
weekly_build:
when:
and:
@ -593,90 +290,17 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
- << pipeline.parameters.test_release >>
jobs:
- build_linux_release:
- build_linux_test_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
cuda_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.cuda_release >>
jobs:
- build_cuda_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"]

View File

@ -17,4 +17,4 @@ jobs:
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files
pre-commit run --all-files

4
.gitignore vendored
View File

@ -36,7 +36,6 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock
# vim
*.swp
@ -77,9 +76,6 @@ build/
*.out
*.app
# Debug symbols
*.pdb
# VSCode
.vscode/
.DS_Store

View File

@ -1,21 +1,16 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
rev: v17.0.6
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.1.0
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.0
rev: 5.13.2
hooks:
- id: isort
args:
- --profile=black
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
hooks:
- id: cmake-format

View File

@ -7,18 +7,12 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@ -1,24 +0,0 @@
cff-version: 1.2.0
title: mlx
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Awni
family-names: Hannun
affiliation: Apple
- given-names: Jagrit
family-names: Digani
affiliation: Apple
- given-names: Angelos
family-names: Katharopoulos
affiliation: Apple
- given-names: Ronan
family-names: Collobert
affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
MLX: efficient and flexible machine learning on Apple
silicon
license: MIT

View File

@ -1,24 +1,6 @@
cmake_minimum_required(VERSION 3.25)
cmake_minimum_required(VERSION 3.24)
if(NOT MLX_VERSION)
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
set(_major ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
set(_minor ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
set(_patch ${CMAKE_MATCH_1})
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
${MLX_VERSION})
endif()
project(
mlx
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})
project(mlx LANGUAGES C CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@ -33,40 +15,36 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
# --------------------- Processor tests -------------------------
message(
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.5.1)
endif()
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(
FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
)
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
# --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
@ -78,227 +56,150 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
if (MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
endif()
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
)
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
mlx PUBLIC
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
set(MLX_BUILD_GGUF OFF)
# There is no prebuilt OpenBLAS distribution for MSVC.
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
endif()
# Windows implementation of dlfcn.h APIs.
FetchContent_Declare(
dlfcn-win32
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.1
EXCLUDE_FROM_ALL)
block()
set(BUILD_SHARED_LIBS OFF)
FetchContent_MakeAvailable(dlfcn-win32)
endblock()
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
target_link_libraries(mlx PRIVATE dl)
endif()
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(MLX_USE_ACCELERATE)
add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
# Download and build OpenBLAS from source code.
FetchContent_Declare(
openblas
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
GIT_TAG v0.3.28
EXCLUDE_FROM_ALL)
set(BUILD_STATIC_LIBS ON) # link statically
set(NOFORTRAN ON) # msvc has no fortran compiler
FetchContent_MakeAvailable(openblas)
target_link_libraries(mlx PRIVATE openblas)
target_include_directories(
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
else()
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if(NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if(NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
#set(BLA_VENDOR Generic)
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
)
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()
if(MLX_BUILD_TESTS)
if (MLX_BUILD_TESTS)
include(CTest)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()
if(MLX_BUILD_EXAMPLES)
if (MLX_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()
if(MLX_BUILD_BENCHMARKS)
if (MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()
# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)
# Install library
install(
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
# Install headers
install(
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING
PATTERN "*.h"
PATTERN "backend/metal/kernels.h" EXCLUDE)
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING PATTERN "*.h"
)
# Install metal dependencies
if(MLX_BUILD_METAL)
if (MLX_BUILD_METAL)
# Install metal cpp
install(
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source)
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source
)
endif()
@ -310,24 +211,31 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install(
EXPORT MLXTargets
FILE MLXTargets.cmake
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
include(CMakePackageConfigHelpers)
write_basic_package_version_file(
${MLX_CMAKE_BUILD_VERSION_CONFIG}
COMPATIBILITY SameMajorVersion
VERSION ${MLX_VERSION})
VERSION ${MLX_VERSION}
)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
${MLX_CMAKE_BUILD_CONFIG}
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
NO_CHECK_REQUIRED_COMPONENTS_MACRO
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
MLX_CMAKE_INSTALL_MODULE_DIR)
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
)
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
install(
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
install(
DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)

View File

@ -5,26 +5,26 @@ possible.
## Pull Requests
1. Fork and submit pull requests to the repo.
1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code.
You can also run the formatters manually as follows:
```shell
clang-format -i file.cpp
```
```shell
black file.py
```
```
clang-format -i file.cpp
```
```
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues

View File

@ -1,6 +1,4 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@ -6,7 +6,7 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning on Apple silicon,
MLX is an array framework for machine learning research on Apple silicon,
brought to you by Apple machine learning research.
Some key features of MLX include:
@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.

View File

@ -5,35 +5,35 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_value_and_grad() {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
auto x = ones({200, 1000});
eval(x);
auto fn = [](array x) {
for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x));
x = log(exp(x));
}
return mx::sum(x);
return sum(x);
};
auto grad_fn = mx::grad(fn);
auto grad_fn = grad(fn);
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
return std::vector<mx::array>{value, dfdx};
return std::vector<array>{value, dfdx};
};
TIME(independent_value_and_grad);
auto value_and_grad_fn = mx::value_and_grad(fn);
auto value_and_grad_fn = value_and_grad(fn);
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<mx::array>{value, dfdx};
return std::vector<array>{value, dfdx};
};
TIME(combined_value_and_grad);
}
int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_value_and_grad();
}

View File

@ -4,21 +4,21 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_add_op() {
std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back());
}
set_default_device(mx::Device::cpu);
set_default_device(Device::cpu);
for (auto size : sizes) {
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
std::cout << "Size " << size << std::endl;
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
TIMEM("cpu", add, a, b, Device::cpu);
TIMEM("gpu", add, a, b, Device::gpu);
}
}

View File

@ -1,111 +1,110 @@
// Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream>
#include <sstream>
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_irregular_binary_ops_1D() {
auto device = mx::default_device();
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = mx::random::uniform({size});
auto b = mx::random::uniform({size});
mx::eval(a, b);
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", mx::add, a, b, device);
TIMEM("1D strided", add, a, b, device);
}
void time_irregular_binary_ops_2D() {
auto device = mx::default_device();
auto device = default_device();
int size = 2048;
auto a = mx::random::uniform({size, size});
auto b = mx::random::uniform({size, size});
mx::eval(a, b);
TIMEM("2D regular", mx::add, a, b, device);
auto a = random::uniform({size, size});
auto b = random::uniform({size, size});
eval(a, b);
TIMEM("2D regular", add, a, b, device);
b = mx::transpose(b);
mx::eval(b);
TIMEM("2D mx::transpose", mx::add, a, b, device);
b = transpose(b);
eval(b);
TIMEM("2D transpose", add, a, b, device);
b = mx::random::uniform({size});
mx::eval(b);
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({size});
eval(b);
TIMEM("2D broadcast dim 0", add, a, b, device);
b = mx::reshape(b, {size, 1});
mx::eval(b);
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
b = reshape(b, {size, 1});
eval(b);
TIMEM("2D broadcast dim 1", add, a, b, device);
}
void time_irregular_binary_ops_3D() {
auto device = mx::default_device();
auto device = default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = mx::random::uniform({d0, d1, d2});
auto b = mx::random::uniform({d0, d1, d2});
TIMEM("3D regular", mx::add, a, b, device);
auto a = random::uniform({d0, d1, d2});
auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", add, a, b, device);
b = mx::transpose(b, {0, 2, 1});
TIMEM("3D mx::transpose", mx::add, a, b, device);
b = transpose(b, {0, 2, 1});
TIMEM("3D transpose", add, a, b, device);
b = mx::random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", add, a, b, device);
b = mx::random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", add, a, b, device);
b = mx::random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
b = mx::random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
b = mx::random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
}
void time_irregular_binary_ops_4D() {
auto device = mx::default_device();
auto device = default_device();
std::vector<int> shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
auto a = random::uniform(shape);
auto b = random::uniform(shape);
TIMEM("4D regular", mx::add, a, b, device);
TIMEM("4D regular", add, a, b, device);
b = mx::transpose(b, {0, 1, 3, 2});
TIMEM("4D mx::transpose", mx::add, a, b, device);
b = transpose(b, {0, 1, 3, 2});
TIMEM("4D transpose", add, a, b, device);
std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
b = mx::random::uniform(shape);
b = random::uniform(shape);
std::ostringstream msg;
msg << om << i;
TIMEM(msg.str(), mx::add, a, b, device);
TIMEM(msg.str(), add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
b = mx::random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device);
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[k] = a.shape(k);
}
}
@ -114,83 +113,83 @@ void time_irregular_binary_ops_4D() {
}
void time_irregular_reshape() {
auto device = mx::default_device();
auto device = default_device();
std::vector<int> shape;
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
auto reshape_fn = [&shape, device](const array& a) {
return reshape(a, shape, device);
};
int size = 64;
int d = 2 * size;
auto a = mx::random::uniform({d, d, d});
auto a = random::uniform({d, d, d});
shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);
a = mx::transpose(a);
a = transpose(a);
shape = {8 * size, size, size};
TIMEM("3D mx::transpose", reshape_fn, a);
TIMEM("3D transpose", reshape_fn, a);
a = mx::transpose(a, {1, 2, 0});
a = transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
TIMEM("3D transpose dims 1 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
a = broadcast_to(random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
a = broadcast_to(random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}
void time_irregular_astype_1D() {
auto device = mx::default_device();
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = mx::random::uniform({size});
auto a = random::uniform({size});
a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", mx::astype, a, mx::int32, device);
TIMEM("1D strided", astype, a, int32, device);
}
void time_irregular_astype_2D() {
auto device = mx::default_device();
auto device = default_device();
int size = 2048;
std::vector<int> shape = {size, size};
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);
auto a = random::uniform(shape);
TIMEM("2D regular", astype, a, int32, device);
a = mx::transpose(a);
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
a = transpose(a);
TIMEM("2D transpose", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", astype, a, int32, device);
}
int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
set_default_device(use_gpu ? Device::gpu : Device::cpu);
}
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();

View File

@ -3,20 +3,20 @@
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
using namespace mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
auto full_fp32 = [&]() { return full(shape, 3.3f); };
TIME(full_fp32);
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
TIME(zeros_fp32);
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
auto ones_fp32 = [&]() { return ones(shape, float32); };
TIME(ones_fp32);
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
@ -24,196 +24,194 @@ void time_type_conversions() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto device = mx::default_device();
auto device = default_device();
auto a = mx::zeros(shape, mx::float32);
mx::eval(a);
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
auto a = zeros(shape, float32);
eval(a);
TIMEM("float32 to int32", astype, a, int32, device);
TIMEM("float32 to uint32", astype, a, uint32, device);
a = mx::zeros(shape, mx::int32);
mx::eval(a);
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
a = zeros(shape, int32);
eval(a);
TIMEM("int32 to float32", astype, a, float32, device);
a = mx::zeros(shape, mx::bool_);
mx::eval(a);
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
a = zeros(shape, bool_);
eval(a);
TIMEM("bool to float32", astype, a, float32, device);
TIMEM("bool to int32", astype, a, int32, device);
TIMEM("bool to uint32", astype, a, uint32, device);
}
void time_random_generation() {
int M = 2000;
int N = 500;
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
auto uniform = [&]() { return random::uniform({M, N}, float32); };
TIME(uniform);
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
auto normal = [&]() { return random::normal({M, N}, float32); };
TIME(normal);
}
void time_unary_ops() {
int M = 2000;
int N = 500;
auto device = mx::default_device();
auto device = default_device();
auto a = mx::random::normal({M, N});
mx::eval(a);
auto a = random::normal({M, N});
eval(a);
TIME(mlx::core::abs, a, device);
TIME(mx::negative, a, device);
TIME(mx::sign, a, device);
TIME(mx::square, a, device);
TIME(negative, a, device);
TIME(sign, a, device);
TIME(square, a, device);
TIME(mlx::core::sqrt, a, device);
TIME(mx::rsqrt, a, device);
TIME(rsqrt, a, device);
TIME(mlx::core::exp, a, device);
a = mx::random::uniform({M, N});
a = random::uniform({M, N});
TIME(mlx::core::log, a, device);
}
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = mx::random::randint(0, 2, {M, N, K});
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(mx::add, a, b, device);
TIME(mx::subtract, a, b, device);
TIME(mx::multiply, a, b, device);
TIME(mx::divide, a, b, device);
TIME(mx::maximum, a, b, device);
TIME(mx::minimum, a, b, device);
TIME(mx::where, condition, a, b, device);
TIME(add, a, b, device);
TIME(subtract, a, b, device);
TIME(multiply, a, b, device);
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = mx::array({true});
b = mx::random::uniform({1});
mx::eval(b);
TIMEM("scalar", mx::add, a, b, device);
TIMEM("vector-scalar", mx::subtract, a, b, device);
TIMEM("scalar-vector", mx::subtract, b, a, device);
TIMEM("scalar", mx::multiply, a, b, device);
TIMEM("vector-scalar", mx::divide, a, b, device);
TIMEM("scalar-vector", mx::divide, b, a, device);
TIMEM("scalar-vector", mx::where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
TIMEM("vector-scalar", subtract, a, b, device);
TIMEM("scalar-vector", subtract, b, a, device);
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
mx::eval(a, b);
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50;
auto a = mx::random::uniform({M, N, O, P});
auto b = mx::random::uniform({M, N, O, P});
auto device = mx::default_device();
mx::eval(a, b);
TIMEM("non-strided", mx::add, a, b, device);
a = mx::transpose(a, {1, 0, 2, 3});
b = mx::transpose(b, {3, 2, 0, 1});
mx::eval(a, b);
TIMEM("strided", mx::add, a, b, device);
auto a = random::uniform({M, N, O, P});
auto b = random::uniform({M, N, O, P});
auto device = default_device();
eval(a, b);
TIMEM("non-strided", add, a, b, device);
a = transpose(a, {1, 0, 2, 3});
b = transpose(b, {3, 2, 0, 1});
eval(a, b);
TIMEM("strided", add, a, b, device);
}
void time_comparisons() {
int M = 1000, N = 100, K = 10;
auto a = mx::random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::equal, a, b, device);
TIME(mx::greater, a, b, device);
TIME(mx::greater_equal, a, b, device);
TIME(mx::less, a, b, device);
TIME(mx::less_equal, a, b, device);
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
eval(a, b);
TIME(equal, a, b, device);
TIME(greater, a, b, device);
TIME(greater_equal, a, b, device);
TIME(less, a, b, device);
TIME(less_equal, a, b, device);
}
void time_matvec() {
int M = 2000, N = 200;
auto a = mx::random::uniform({M, N});
auto b = mx::random::uniform({N});
auto c = mx::random::uniform({M});
mx::eval(a, b, c);
auto matvec = [&]() { return mx::matmul(a, b); };
auto a = random::uniform({M, N});
auto b = random::uniform({N});
auto c = random::uniform({M});
eval(a, b, c);
auto matvec = [&]() { return matmul(a, b); };
TIME(matvec);
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
auto a = mx::random::uniform({M, K});
auto b = mx::random::uniform({K, N});
auto device = mx::default_device();
mx::eval(a, b);
TIME(mx::matmul, a, b, device);
auto a = random::uniform({M, K});
auto b = random::uniform({K, N});
auto device = default_device();
eval(a, b);
TIME(matmul, a, b, device);
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
TIME(transpose_matmul);
}
void time_reductions() {
auto a = mx::random::normal({10000, 1000});
mx::eval(a);
auto sum_all = [&a]() { return mx::sum(a, false); };
auto a = random::normal({10000, 1000});
eval(a);
auto sum_all = [&a]() { return sum(a, false); };
TIME(sum_all);
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
TIME(sum_along_0);
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
TIME(sum_along_1);
auto prod_all = [&a]() { return mx::prod(a, false); };
auto prod_all = [&a]() { return prod(a, false); };
TIME(prod_all);
auto all_true = [&a]() { return mx::all(a, false); };
auto all_true = [&a]() { return all(a, false); };
TIME(all_true);
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
auto all_along_0 = [&a]() { return all(a, 0, false); };
TIME(all_along_0);
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
auto all_along_1 = [&a]() { return all(a, 1, false); };
TIME(all_along_1);
auto any_true = [&a]() { return mx::any(a, false); };
auto any_true = [&a]() { return any(a, false); };
TIME(any_true);
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
TIME(argmin_along_0);
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
TIME(argmin_along_1);
}
void time_gather_scatter() {
auto a = mx::random::normal({1000, 768});
mx::eval(a);
auto indices = mx::random::randint(0, 1000, {256});
mx::eval(indices);
auto a = random::normal({1000, 768});
eval(a);
auto indices = random::randint(0, 1000, {256});
eval(indices);
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
TIME(embedding_lookup);
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
mx::eval(indices);
indices = random::randint(0, 768 * 1000, {256 * 768});
eval(indices);
auto single_element_lookup = [&a, &indices]() {
return mx::take(a, indices);
};
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
TIME(single_element_lookup);
indices = mx::random::randint(0, 1000, {256});
auto updates = mx::random::normal({256, 1, 768});
mx::eval(indices, updates);
indices = random::randint(0, 1000, {256});
auto updates = random::normal({256, 1, 768});
eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@ -225,10 +223,10 @@ void time_gather_scatter() {
};
TIME(embedding_add);
a = mx::reshape(a, {-1});
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
updates = mx::random::normal({256 * 768, 1});
mx::eval(a, indices, updates);
a = reshape(a, {-1});
indices = random::randint(0, 768 * 1000, {768 * 256});
updates = random::normal({256 * 768, 1});
eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@ -242,21 +240,21 @@ void time_gather_scatter() {
}
void time_divmod() {
auto a = mx::random::normal({1000});
auto b = mx::random::normal({1000});
mx::eval({a, b});
auto a = random::normal({1000});
auto b = random::normal({1000});
eval({a, b});
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
std::cout << "Benchmarks for " << default_device() << std::endl;
time_creation_ops();
time_type_conversions();
time_unary_ops();

View File

@ -17,13 +17,14 @@
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
<< std::flush << std::setprecision(5) \
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args&&... args) {
double time_fn(F fn, Args... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));

View File

@ -144,13 +144,6 @@ def reduction(op, axis, x):
mx.eval(ys)
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
mx.eval(z)
def softmax(axis, x):
ys = []
for i in range(100):
@ -387,6 +380,10 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.cpu:
mx.set_default_device(mx.cpu)
else:
@ -409,10 +406,6 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))
@ -512,8 +505,5 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError("Unknown benchmark")

View File

@ -5,7 +5,6 @@ import os
import time
import torch
import torch.cuda
import torch.mps
@ -45,10 +44,8 @@ def bench(f, *args):
def sync_if_needed(x):
if x.device == torch.device("mps"):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()
@torch.no_grad()
@ -102,14 +99,6 @@ def reduction(op, axis, x):
sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
ys = []
@ -196,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
def mish(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
y = torch.nn.functional.mish(y)
return torch.nn.functional.mish(y)
sync_if_needed(x)
@ -294,14 +283,6 @@ def topk(axis, x):
sync_if_needed(x)
@torch.no_grad()
def step_function(x):
y = x
for i in range(100):
y = torch.where(y < 0, 0, 1)
sync_if_needed(x)
@torch.no_grad()
def selu(x):
y = x
@ -350,12 +331,12 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
torch.set_num_threads(1)
device = "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
device = "cpu" if args.cpu else "mps"
types = args.dtype
if not types:
@ -373,10 +354,6 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))
@ -469,14 +446,5 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
raise ValueError("Unknown benchmark")

View File

@ -16,9 +16,7 @@ def run_or_raise(*args, **kwargs):
result = run(*args, capture_output=True, **kwargs)
return float(result.stdout)
except ValueError:
raise ValueError(
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
)
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
def compare(args):

View File

@ -9,6 +9,7 @@ from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@ -50,6 +51,7 @@ def bench_gelu():
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)

View File

@ -1,123 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
def mx_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
@torch.no_grad()
def pt_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
scale = 1.0 / math.sqrt(wH * C)
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_1D(strides, padding, groups)
f_pt = make_pt_conv_1D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv1d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 5, 32, 1, 2, 1),
(4, 32, 32, 5, 32, 1, 2, 2),
(4, 32, 32, 5, 32, 1, 2, 4),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 16),
(4, 32, 32, 5, 32, 1, 2, 32),
(4, 32, 256, 5, 512, 1, 2, 2),
(4, 32, 256, 5, 512, 1, 2, 128),
(4, 32, 256, 5, 512, 1, 2, 256),
)
for dtype in dtypes:
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
for N, iH, C, wH, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, iH, C, wH, O, strides, padding, np_dtype, groups
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -1,127 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -1,143 +0,0 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal([10, 256, 256, 3])
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(10, 3, 256, 256, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 20
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()

View File

@ -1,129 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -1,110 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -1,143 +0,0 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal(shape)
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(*shape, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 10
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()

View File

@ -1,116 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose3d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -28,11 +28,11 @@ def bench(f, a, b):
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
y = mx.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
mx.eval(ys)
return ys
@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
y = torch.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
torch.mps.synchronize()
return ys
@ -53,12 +53,11 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
@ -68,15 +67,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
@ -85,7 +84,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
@ -96,40 +95,35 @@ if __name__ == "__main__":
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
N, H, W, C, kH, kW, O, strides, padding, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -1,135 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -1,107 +0,0 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -1,66 +0,0 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""
import time
import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
world = mx.distributed.init()
if world.rank() == 0:
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
mx.eval(fn(*args, **kwargs))
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
x = mx.eval(fn(*args, **kwargs))
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
if world.rank() == 0:
print(f"{msec:.5f} msec")
def time_all_sum():
shape = (4096,)
x = mx.random.uniform(shape=shape)
mx.eval(x)
def sine(x):
for _ in range(20):
x = mx.sin(x)
return x
time_fn(sine, x)
def all_sum_plain(x):
for _ in range(20):
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_plain, x)
def all_sum_with_sine(x):
for _ in range(20):
x = mx.sin(x)
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_with_sine, x)
if __name__ == "__main__":
time_all_sum()

View File

@ -1,84 +0,0 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
for _ in range(5):
fn(*args)
tic = time.perf_counter()
for _ in range(its):
fn(*args)
toc = time.perf_counter()
return 1e3 * (toc - tic) / its
def time_little_einsum_path():
subscripts = "ik,kj->ij"
x = mx.ones((32, 32))
y = mx.ones((32, 32))
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
x = np.array(x)
y = np.array(y)
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
print("Timing little einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_big_einsum_path():
chars = list("abcdefgh")
char_to_dim = {c: v for v, c in enumerate(chars)}
num_inputs = 10
inputs = []
subscripts = []
for _ in range(num_inputs):
subscript = np.random.choice(chars, size=5, replace=False).tolist()
subscripts.append("".join(subscript))
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
subscripts = ",".join(subscripts)
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
inputs = [mx.array(x) for x in inputs]
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
print("Timing big einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_attention():
def regular_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
scores = mx.softmax(scores, axis=-1)
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
mx.eval(output)
def einsum_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
scores = mx.softmax(scores, axis=-1)
output = mx.einsum("ijtu,iujk->itjk", scores, values)
mx.eval(output)
x = mx.random.uniform(shape=(8, 512, 32, 128))
regular_time = timeit(regular_attention, args=(x,))
ein_time = timeit(einsum_attention, args=(x,))
print("Timing einsum attention...")
print(f"Regular ... {regular_time:.3f} ms")
print(f"Einsum ... {ein_time:.3f} ms")
if __name__ == "__main__":
time_little_einsum_path()
time_big_einsum_path()
time_attention()

View File

@ -1,118 +0,0 @@
# Copyright © 2024 Apple Inc.
import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def bandwidth_gb(runtime_ms, system_size):
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
bytes_per_gb = 1e9
ms_per_s = 1e3
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
def fft_mlx(x):
if dim == 1:
out = mx.fft.fft(x)
elif dim == 2:
out = mx.fft.fft2(x)
mx.eval(out)
return out
def fft_mps(x):
if dim == 1:
out = torch.fft.fft(x)
elif dim == 2:
out = torch.fft.fft2(x)
torch.mps.synchronize()
return out
bandwidths = []
for n in fft_sizes:
batch_size = system_size // n**dim
shape = [batch_size] + [n for _ in range(dim)]
if backend == "mlx":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = mx.array(x_np)
mx.eval(x)
fft = fft_mlx
elif backend == "mps":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = torch.tensor(x_np, device="mps")
torch.mps.synchronize()
fft = fft_mps
else:
raise NotImplementedError()
runtime_ms = measure_runtime(fft, x=x)
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
print(n, bandwidth)
bandwidths.append(bandwidth)
return np.array(bandwidths)
def time_fft():
x = np.array(range(2, 512))
system_size = int(2**26)
print("MLX GPU")
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
print("MPS GPU")
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
print("CPU")
system_size = int(2**20)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
x = np.array(x)
all_indices = x - x[0]
radix_2to13 = (
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
)
bluesteins = (
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
)
for indices, name in [
(all_indices, "All"),
(radix_2to13, "Radix 2-13"),
(bluesteins, "Bluestein's"),
]:
# plot bandwidths
print(name)
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
plt.title(f"MLX FFT Benchmark -- {name}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"{name}.png")
plt.clf()
av_gpu_bandwidth = np.mean(gpu_bandwidths)
av_mps_bandwidth = np.mean(mps_bandwidths)
av_cpu_bandwidth = np.mean(cpu_bandwidths)
print("Average bandwidths:")
print("GPU:", av_gpu_bandwidth)
print("MPS:", av_mps_bandwidth)
print("CPU:", av_cpu_bandwidth)
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
print("Percent MLX faster than MPS: ", portion_faster * 100)
if __name__ == "__main__":
time_fft()

View File

@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch

View File

@ -1,74 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@ -1,84 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@ -1,70 +0,0 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
y = mx.hadamard_transform(x)
mx.eval(y)
def copy(x):
y = x + 1.0
mx.eval(y)
def run(dtype):
system_size = 2**26
outputs = {}
for test_fn in (had, copy):
for m in [1, 12, 20, 28]:
if test_fn == copy:
key = "copy"
elif m == 1:
key = "had_2^k"
else:
key = "had_m*2^k"
outputs.setdefault(key, {})
for k in range(7, 14):
n = m * 2**k
if n > 2**15:
continue
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
x = mx.array(x_np)
runtime_ms = measure_runtime(test_fn, x=x)
bytes_per_gb = 1e9
ms_per_s = 1e3
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
bandwidth_gb = (
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
)
print(n, bandwidth_gb)
outputs[key][n] = bandwidth_gb
colors = {
"copy": "black",
"had_2^k": "steelblue",
"had_m*2^k": "skyblue",
}
for key, output in outputs.items():
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"bench_{dtype.__name__}.png")
plt.clf()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
dtype = np.float16 if args.fp16 else np.float32
run(dtype)

View File

@ -1,82 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def layer_norm(x, w, b, eps):
ot = x.dtype
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
y = (x - mu) * mx.rsqrt(v + eps)
if w is not None:
y = y * w
if b is not None:
y = y + b
return y
def time_layer_norm(N, dt):
L = 1024
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_loop(f, x, w, b):
for _ in range(32):
x = f(x, w, b)
return x
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
def layer_norm_grad_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_grad_loop, g1, x, w, b)
time_fn(layer_norm_grad_loop, g2, x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_grad_x_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(layer_norm_grad_x_loop, g1, x)
time_fn(layer_norm_grad_x_loop, g2, x)
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
if __name__ == "__main__":
for dt in [mx.float32, mx.float16, mx.bfloat16]:
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
print(dt, n)
time_layer_norm(n, dt)

View File

@ -1,63 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
y = (x * n).astype(ot)
if w is not None:
y = y * w
return y
def time_rms_norm():
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1))
g2 = mx.grad(f2, argnums=(0, 1))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x, w):
gx, gw = x, w
for _ in range(32):
gx, gw = g(gx, gw, y)
return gx, gw
time_fn(rms_norm_loop, g1, x, w)
time_fn(rms_norm_loop, g2, x, w)
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(rms_norm_loop, g1, x)
time_fn(rms_norm_loop, g2, x)
time_fn(rms_norm_loop, mx.compile(g1), x)
time_fn(rms_norm_loop, mx.compile(g2), x)
if __name__ == "__main__":
time_rms_norm()

View File

@ -6,21 +6,21 @@ from time_utils import time_fn
def time_rope():
rope = nn.RoPE(64)
rope = nn.RoPE(4096)
# vec
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x, offset=100)
x = rope(x)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):

View File

@ -9,7 +9,7 @@ from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[tuple(idx)] = x
dst[*idx] = x
mx.eval(dst)
idx = []
@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def scatter(dst, x, idx, device):
dst[tuple(idx)] = x
def gather(dst, x, idx, device):
dst[*idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
@ -54,7 +54,7 @@ if __name__ == "__main__":
(100_000, 64),
(1_000_000, 64),
(100_000,),
(200_000,),
(2_000_00,),
(20_000_000,),
(10000, 64),
(100, 64),
@ -91,6 +91,6 @@ if __name__ == "__main__":
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@ -1,223 +0,0 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
def bench(f, *args):
for i in range(N_warmup):
f(*args)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
np_dtype = getattr(np, dtype)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
kL = k.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if mask is not None:
if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_fused_attn(q, k, v, scale, mask):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose:
q_t = mx.transpose(q, (0, 2, 1, 3))
k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
q_out = q
for i in range(N_iter_func):
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
mx.eval(q_out)
return q_out
def bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
):
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
)
time_mlx_unfused = bench(
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
atol = 1e-5 if dtype == "float32" else 2e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 8),
( 1, 2048, 2048, 64, 32, 8),
( 1, 4096, 4096, 64, 32, 8),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 8),
( 1, 2048, 2048, 80, 32, 8),
( 1, 4096, 4096, 80, 32, 8),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 8),
( 1, 2048, 2048, 128, 32, 8),
( 1, 4096, 4096, 128, 32, 8),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
masks = [None, "bool", "causal"]
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
for mask_in in masks:
time_mlx_fused, time_mlx_unfused = bench_shape(
B,
qsl,
ksl,
head_dim,
n_q_heads,
n_kv_heads,
dtype,
transpose,
mask_in,
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@ -1,95 +0,0 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
L = 16384
H = 32
H_k = H // 4
D = 128
V = 128
dtype = mx.float16
loops = 10
def upproject(x, w):
if w is None:
return x
else:
return x @ w.T
def attention(q, k, v, mask=None, w=None):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
_, _, _, V = v.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
if mask is not None:
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, V)
for i in range(loops):
q = _sdpa(q, k, v)
q = upproject(q, w)
return q
def sdpa(q, k, v, mask=None, w=None):
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q = upproject(q, w)
return q
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mx.eval(q, k, v, w)
time_fn(attention, q, k, v, w=w)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mx.eval(q, k, v, w)
time_fn(sdpa, q, k, v, w=w)
def time_self_attention_sdpa_with_mask():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mask = mx.full((L,), True)
mask[L // 2 :] = False
mx.eval(q, k, v, mask, w)
def sdpa_mask(*args):
return sdpa(*args, mask=mask, w=w)
def attention_mask(*args):
return attention(*args, mask=mask, w=w)
time_fn(attention_mask, q, k, v)
time_fn(sdpa_mask, q, k, v)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()
time_self_attention_sdpa_with_mask()

View File

@ -1,55 +0,0 @@
import time
import mlx.core as mx
rank = mx.distributed.init().rank()
def timeit(fn, a):
# warmup
for _ in range(5):
mx.eval(fn(a))
its = 10
tic = time.perf_counter()
for _ in range(its):
mx.eval(fn(a))
toc = time.perf_counter()
ms = 1000 * (toc - tic) / its
return ms
def all_reduce_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_sum(x)
x = x - 1
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All Reduce: time per iteration {ms:.6f} (ms)")
def all_gather_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_gather(x)[0]
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All gather: time per iteration {ms:.6f} (ms)")
if __name__ == "__main__":
all_reduce_benchmark()
all_gather_benchmark()

View File

@ -1,50 +1,56 @@
include(CMakeParseArguments)
# clang format off
#
# ##############################################################################
###############################################################################
# Build metal library
#
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
#
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
# Args:
# TARGET: Custom target to be added for the metal library
# TITLE: Name of the .metallib
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of dependency files (like headers)
#
# clang format on
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cmake_parse_arguments(
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
# Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Prepare metallib build command
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS}
${MTLLIB_SOURCES}
-o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM)
VERBATIM
)
# Add metallib custom target
add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
add_custom_target(
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
endmacro(mlx_build_metallib)
endmacro(mlx_build_metallib)

View File

@ -1,50 +0,0 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = NO
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@ -2,16 +2,12 @@
### Setup (do once)
Install Doxygen:
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
```
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
conda install sphinx
pip install sphinx-book-theme
```
### Build
@ -19,7 +15,7 @@ pip install -r requirements.txt
Build the docs from `mlx/docs/`
```
doxygen && make html
make html
```
View the docs by running a server in `mlx/docs/build/html/`:

View File

@ -1,4 +0,0 @@
sphinx
breathe
sphinx-book-theme
mlx

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 746 KiB

View File

@ -1,20 +0,0 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != "__init__" %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}

View File

@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, Apple"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
release = version
@ -22,7 +22,6 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
@ -30,20 +29,16 @@ autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"
main_doc = "index"
master_doc = "index"
highlight_language = "python"
pygments_style = "sphinx"
add_module_names = False
# -- Options for HTML output -------------------------------------------------
@ -60,39 +55,7 @@ html_theme_options = {
},
}
html_favicon = html_theme_options["logo"]["image_light"]
# -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc"
def setup(app):
from sphinx.util import inspect
wrapped_isfunc = inspect.isfunction
def isfunc(obj):
type_name = str(type(obj))
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
return True
return wrapped_isfunc(obj)
inspect.isfunction = isfunc
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@ -3,5 +3,4 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@ -1,445 +0,0 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note::
Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
The full function signature will be generated using:
* The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``.
* The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``.
* Template parameters passed using ``template``
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
These will be added as function arguments.
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:
.. code-block:: cpp
template <typename T>
[[kernel]] void custom_kernel_myexp_float(
const device float16_t* inp [[buffer(0)]],
device float16_t* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
}
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Using Shape/Strides
-------------------
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Complex Example
-----------------------------
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
``55.7ms -> 6.7ms => 8x speed up``
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
* ``atomic_outputs=True``
Designate all of the kernel outputs as ``atomic`` in the function signature.
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
We can then implement the backwards pass as follows:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T gix_mult = W / 2;
T giy_mult = H / 2;
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:
``676.4ms -> 16.7ms => 40x speed up``

File diff suppressed because it is too large Load Diff

View File

@ -1,68 +0,0 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Checkout the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example schema and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@ -1,121 +0,0 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
This guide is based one the following `example using MLX in C++
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next make an example program in ``example.cpp``:
.. code-block:: C++
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
The next step is to setup a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
:widths: 20 20
:header-rows: 1
* - Variable
- Description
* - MLX_FOUND
- ``True`` if MLX is found
* - MLX_INCLUDE_DIRS
- Include directory
* - MLX_LIBRARIES
- Libraries to link against
* - MLX_CXX_FLAGS
- Additional compiler flags
* - MLX_BUILD_ACCELERATE
- ``True`` if MLX was built with Accelerate
* - MLX_BUILD_METAL
- ``True`` if MLX was built with Metal

View File

@ -15,7 +15,7 @@ module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the Llama attention layer which notably uses the RoPE
We will start with the llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.

View File

@ -64,7 +64,7 @@ set:
Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.
we will import as `mnist`.
.. code-block:: python

View File

@ -43,9 +43,7 @@ are the CPU and GPU.
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
usage/export
.. toctree::
:caption: Examples
@ -60,20 +58,15 @@ are the CPU and GPU.
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/export
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
python/memory_management
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::
@ -87,6 +80,3 @@ are the CPU and GPU.
:maxdepth: 1
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels
dev/mlx_in_cpp

View File

@ -1,5 +1,3 @@
.. _build_and_install:
Build and Install
=================
@ -16,11 +14,11 @@ silicon computer is
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- macOS >= 13.5
- Using a native Python >= 3.8
- macOS >= 13.3
.. note::
MLX is only available on devices running macOS >= 13.5
MLX is only available on devices running macOS >= 13.3
It is highly recommended to use macOS 14 (Sonoma)
@ -30,16 +28,6 @@ MLX is also available on conda-forge. To install MLX with conda do:
conda install conda-forge::mlx
CUDA
^^^^
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell
pip install mlx-cuda
Troubleshooting
^^^^^^^^^^^^^^^
@ -65,8 +53,8 @@ Build Requirements
^^^^^^^^^^^^^^^^^^
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
@ -75,8 +63,6 @@ Build Requirements
Python API
^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
@ -84,43 +70,44 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Then simply build and install MLX using pip:
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
For developing, install the package with development dependencies, and use an
editable install:
Then simply build and install it using pip:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
Once the development dependencies are installed, you can build faster with:
For developing use an editable install:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
Run the tests with:
To make sure the install is working run the tests with:
.. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your
IDE:
Optional: Install stubs to enable auto completions and type checking from your IDE:
.. code-block:: shell
pip install ".[dev]"
python setup.py generate_stubs
C++ API
^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start
@ -136,7 +123,7 @@ Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
cmake .. && make -j
Run tests with:
@ -155,7 +142,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.
.. list-table:: Build Options
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
@ -169,112 +156,31 @@ should point to the path to the built metal library.
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
Further, you can use the following command to find out which
macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
.. code-block:: shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
@ -296,12 +202,12 @@ Then set the active developer directory:
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
x86 Shell
~~~~~~~~~
.. _build shell:
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
@ -325,4 +231,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cache with ``rm -rf build/`` and try again.
wipe your build cahce with ``rm -rf build/`` and try again.

View File

@ -10,42 +10,27 @@ Array
array
array.astype
array.at
array.item
array.tolist
array.dtype
array.itemsize
array.nbytes
array.ndim
array.shape
array.size
array.real
array.imag
Dtype
array.abs
array.all
array.any
array.argmax
array.argmin
array.conj
array.cos
array.cummax
array.cummin
array.cumprod
array.cumsum
array.diag
array.diagonal
array.dtype
array.exp
array.flatten
array.log
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean
array.min
array.moveaxis
array.prod
array.reciprocal
array.reshape
@ -55,11 +40,7 @@ Array
array.split
array.sqrt
array.square
array.squeeze
array.std
array.sum
array.swapaxes
array.transpose
array.T
array.var
array.view

View File

@ -1,5 +1,7 @@
.. _data_types:
:orphan:
Data Types
==========
@ -42,37 +44,9 @@ The default floating point type is ``float32`` and the default integer type is
* - ``int64``
- 8
- 64-bit signed integer
* - ``bfloat16``
- 2
- 16-bit brain float (e8, m7)
* - ``float16``
- 2
- 16-bit IEEE float (e5, m10)
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
* - ``float32``
- 4
- 32-bit float
* - ``float64``
- 4
- 64-bit double
* - ``complex64``
- 8
- 64-bit complex float
.. note::
Arrays with type ``float64`` only work with CPU operations. Using
``float64`` arrays on the GPU will result in an exception.
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype
finfo

View File

@ -16,4 +16,3 @@ Devices and Streams
new_stream
set_default_stream
stream
synchronize

View File

@ -1,22 +0,0 @@
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
:toctree: _autosummary
Group
is_available
init
all_sum
all_gather
send
recv
recv_like

View File

@ -1,14 +0,0 @@
.. _export:
Export Functions
================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
export_function
import_function
exporter
export_to_dot

View File

@ -1,15 +0,0 @@
.. _fast:
Fast
====
.. currentmodule:: mlx.core.fast
.. autosummary::
:toctree: _autosummary
rms_norm
layer_norm
rope
scaled_dot_product_attention
metal_kernel

View File

@ -20,5 +20,3 @@ FFT
irfft2
rfftn
irfftn
fftshift
ifftshift

View File

@ -5,23 +5,8 @@ Linear Algebra
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
.. autosummary::
:toctree: _autosummary
inv
tri_inv
norm
cholesky
cholesky_inv
cross
qr
svd
eigvals
eig
eigvalsh
eigh
lu
lu_factor
pinv
solve
solve_triangular

View File

@ -1,16 +0,0 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

View File

@ -3,10 +3,12 @@ Metal
.. currentmodule:: mlx.core.metal
.. autosummary::
.. autosummary::
:toctree: _autosummary
is_available
device_info
start_capture
stop_capture
get_active_memory
get_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit

View File

@ -173,8 +173,6 @@ In detail:
:toctree: _autosummary
value_and_grad
quantize
average_gradients
.. toctree::

View File

@ -13,13 +13,10 @@ simple functions.
:template: nn-module-template.rst
elu
celu
gelu
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
@ -32,7 +29,6 @@ simple functions.
sigmoid
silu
softmax
softmin
softplus
softshrink
step

View File

@ -12,58 +12,32 @@ Layers
ALiBi
AvgPool1d
AvgPool2d
AvgPool3d
BatchNorm
CELU
Conv1d
Conv2d
Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout
Dropout2d
Dropout3d
Embedding
ELU
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LogSigmoid
LogSoftmax
LSTM
MaxPool1d
MaxPool2d
MaxPool3d
Mish
MultiHeadAttention
PReLU
QuantizedEmbedding
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
Sigmoid
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample
Upsample

View File

@ -30,7 +30,6 @@ Module
Module.named_modules
Module.parameters
Module.save_weights
Module.set_dtype
Module.train
Module.trainable_parameters
Module.unfreeze

View File

@ -5,14 +5,13 @@ Operations
.. currentmodule:: mlx.core
.. autosummary::
.. autosummary::
:toctree: _autosummary
abs
add
addmm
all
allclose
allclose
any
arange
arccos
@ -20,80 +19,49 @@ Operations
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argpartition
argsort
array_equal
as_strided
atleast_1d
atleast_2d
atleast_3d
bitwise_and
bitwise_invert
bitwise_or
bitwise_xor
block_masked_mm
broadcast_arrays
broadcast_to
ceil
clip
concatenate
contiguous
conj
conjugate
convolve
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general
cos
cosh
cummax
cummin
cumprod
cumsum
degrees
dequantize
diag
diagonal
divide
divmod
einsum
einsum_path
equal
erf
erfinv
exp
expm1
expand_dims
eye
flatten
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
hadamard_transform
identity
imag
inner
isfinite
isclose
isinf
isnan
isneginf
isposinf
issubdtype
kron
left_shift
isneginf
isinf
less
less_equal
linspace
@ -103,7 +71,6 @@ Operations
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or
@ -112,32 +79,22 @@ Operations
max
maximum
mean
meshgrid
min
minimum
moveaxis
multiply
nan_to_num
negative
not_equal
ones
ones_like
outer
partition
pad
power
prod
put_along_axis
quantize
quantized_matmul
radians
real
reciprocal
remainder
repeat
reshape
right_shift
roll
round
rsqrt
save
@ -149,8 +106,6 @@ Operations
sign
sin
sinh
slice
slice_update
softmax
sort
split
@ -158,7 +113,6 @@ Operations
square
squeeze
stack
std
stop_gradient
subtract
sum
@ -170,14 +124,11 @@ Operations
tensordot
tile
topk
trace
transpose
tri
tril
triu
unflatten
var
view
where
zeros
zeros_like

View File

@ -1,7 +1,5 @@
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
@ -31,48 +29,8 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
.. toctree::
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
clip_grad_norm

View File

@ -18,4 +18,3 @@ Common Optimizers
AdamW
Adamax
Lion
MultiOptimizer

View File

@ -38,11 +38,8 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
gumbel
key
normal
multivariate_normal
randint
seed
split
truncated_normal
uniform
laplace
permutation

View File

@ -9,9 +9,7 @@ Transforms
:toctree: _autosummary
eval
async_eval
compile
custom_function
disable_compile
enable_compile
grad

View File

@ -19,5 +19,3 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_flatten
tree_unflatten
tree_map
tree_map_with_path
tree_reduce

View File

@ -33,12 +33,12 @@ Let's start with a simple example:
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
@ -96,7 +96,7 @@ element-wise operations:
.. code-block:: python
def gelu(x):
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
@ -136,6 +136,13 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
@ -280,7 +287,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
print(fun(mx.array(1.0)))
Compiling Training Graphs
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
@ -290,7 +297,7 @@ full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@ -323,7 +330,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@ -348,7 +355,7 @@ appropriate input and output captures. Here's the same example but compiled:
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@ -403,7 +410,7 @@ Compiling transformed functions works just as expected:
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
@ -421,77 +428,3 @@ the most opportunity to optimize the computation graph:
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)
.. _shapeless_compile:
Shapeless Compilation
---------------------
When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.
.. code-block:: python
def fun(x, y):
return mx.abs(x + y)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.array(1.0)
y = mx.array(-2.0)
# Firt call compiles the function
print(compiled_fun(x, y))
# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))
Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:
.. code-block:: python
def fun(x):
return x.reshape(x.shape[0] * x.shape[1], -1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)
The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
.. code-block:: python
def fun(x):
return x.flatten(0, 1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Ok
out = compiled_fun(x)

View File

@ -1,344 +0,0 @@
.. _usage_distributed:
Distributed Communication
=========================
.. currentmodule:: mlx.core.distributed
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
.. note::
Some operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
A distributed program in MLX is as simple as:
.. code:: python
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. However, when this script is run with ``python`` only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:
.. code:: python
import mlx.core as mx
x = ...
world = mx.distributed.init()
# No need for the check we can simply do x = mx.distributed.all_sum(x)
if world.size() > 1:
x = mx.distributed.all_sum(x)
Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
MLX provides ``mlx.launch`` a helper script to launch distributed programs.
Continuing with our initial example we can run it on localhost with 4 processes using
.. code:: shell
$ mlx.launch -n 4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)
.. code:: shell
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.
Selecting Backend
^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
.. note::
After a distributed backend is successfully initialized :func:`init` will
return **the same backend** if called without arguments or with backend set to
``any``.
The following examples aim to clarify the backend initialization logic in MLX:
.. code:: python
# Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
world = mx.distributed.init(backend="mpi")
world2 = mx.distributed.init() # subsequent calls return the MPI backend!
# Case 2: Initialize any backend
world = mx.distributed.init(backend="any") # equivalent to no arguments
world2 = mx.distributed.init() # same as above
# Case 3: Initialize both backends at the same time
world_mpi = mx.distributed.init(backend="mpi")
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init() # same as MPI because it was initialized first!
Training Example
----------------
In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with following function.
.. code:: python
def all_avg(x):
return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together our training loop step looks as follows with
everything else remaining the same.
.. code:: python
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init().size()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads
)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = all_reduce_grads(grads) # <--- This line was added
optimizer.update(model, grads)
return loss
Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Although the code example above works correctly; it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.
This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
Getting Started with Ring
-------------------------
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver is not supported in the ring backend.
Defining a Ring
^^^^^^^^^^^^^^^
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.
Thunderbolt Ring
^^^^^^^^^^^^^^^^
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
To validate your connection without configuring anything
``mlx.distributed_config`` can also plot the ring using DOT format.
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.

View File

@ -1,288 +0,0 @@
.. _export_usage:
Exporting Functions
===================
.. currentmodule:: mlx.core
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return x + y
x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("add.mlxfn", fun, x, y)
To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:
.. code-block:: python
add_fun = mx.import_function("add.mlxfn")
out, = add_fun(mx.array(1.0), mx.array(2.0))
# Prints: array(3, dtype=float32)
print(out)
out, = add_fun(mx.array(1.0), mx.array(3.0))
# Prints: array(4, dtype=float32)
print(out)
# Raises an exception
add_fun(mx.array(1), mx.array(3.0))
# Raises an exception
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs are different than the shapes and types of the
example inputs we exported the function with.
Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.
The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:
.. code-block:: python
def fun(x, y):
return x + y
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
# Same as above
mx.export_function("add.mlxfn", fun, (x, y))
imported_fun = mx.import_function("add.mlxfn")
# Ok
out, = imported_fun(x, y)
# Also ok
out, = imported_fun((x, y))
You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.
.. code-block:: python
def fun(x, y):
return x + y
# One argument to fun is positional, the other is a kwarg
mx.export_function("add.mlxfn", fun, x, y=y)
imported_fun = mx.import_function("add.mlxfn")
# Ok
out, = imported_fun(x, y=y)
# Also ok
out, = imported_fun((x,), {"y": y})
# Raises since the keyword argument is missing
out, = imported_fun(x, y)
# Raises since the keyword argument has the wrong key
out, = imported_fun(x, z=y)
Exporting Modules
-----------------
An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:
.. code-block:: python
model = nn.Linear(4, 4)
mx.eval(model.parameters())
def call(x):
return model(x)
mx.export_function("model.mlxfn", call, mx.zeros(4))
In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.
.. note::
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:
.. code-block:: python
model = nn.Linear(4, 4)
mx.eval(model.parameters())
def call(x, **params):
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
Shapeless Exports
-----------------
Just like :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:
.. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.
Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.
Exporting Multiple Traces
-------------------------
In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).
The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:
.. code-block:: python
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
exporter(mx.array(1.0))
exporter(mx.array(1.0), y=mx.array(0.0))
imported_function = mx.import_function("fun.mlxfn")
# Call the function with y=None
out, = imported_function(mx.array(1.0))
print(out)
# Call the function with y specified
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
Transformations with Imported Functions
---------------------------------------
Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
on imported functions just like regular Python functions:
.. code-block:: python
def fun(x):
return mx.sin(x)
x = mx.array(0.0)
mx.export_function("sine.mlxfn", fun, x)
imported_fun = mx.import_function("sine.mlxfn")
# Take the derivative of the imported function
dfdx = mx.grad(lambda x: imported_fun(x)[0])
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
Importing Functions in C++
--------------------------
Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
setup a simple C++ project that uses MLX as a library.
Next, export a simple function from Python:
.. code-block:: python
def fun(x, y):
return mx.exp(x + y)
x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("fun.mlxfn", fun, x, y)
Import and run the function in C++ with only a few lines of code:
.. code-block:: c++
auto fun = mx::import_function("fun.mlxfn");
auto inputs = {mx::array(1.0), mx::array(1.0)};
auto outputs = fun(inputs);
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.
More Examples
-------------
Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:
* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_

View File

@ -25,7 +25,7 @@ Here is a simple example:
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
function. To get the second derivative you can do:
.. code-block:: shell
@ -40,7 +40,7 @@ getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
@ -50,7 +50,7 @@ Automatic Differentiation
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
graphs.
.. note::
@ -114,7 +114,7 @@ way to do that is the following:
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
@ -132,7 +132,7 @@ way to do that is the following:
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
@ -161,19 +161,19 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
where the vectorized axes should be in the outputs.
Let's time these two different versions:
@ -184,8 +184,8 @@ Let's time these two different versions:
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
@ -77,12 +77,12 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which output
general, MLX has limited support for operations for which outputs
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
In Place Updates
In Place Updates
----------------
In place updates to indexed arrays are possible in MLX. For example:
@ -107,16 +107,6 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
Transformations of functions which use in-place updates are allowed and work as
expected. For example:

View File

@ -1,105 +0,0 @@
:orphan:
.. _usage_launch_distributed:
Launching Distributed Programs
==============================
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.
Usage
-----
The minimal usage example of ``mlx.launch`` is simply
.. code:: shell
mlx.launch --hosts ip1,ip2 my_script.py
or for testing on localhost
.. code:: shell
mlx.launch -n 2 my_script.py
The ``mlx.launch`` command connects to the provided host and launches the input
script on each host. It monitors each of the launched processes and terminates
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.
Providing Hosts
^^^^^^^^^^^^^^^^
Hosts can be provided as command line arguments, like above, but the way that
allows to fully define a list of hosts is via a JSON hostfile. The hostfile has
a very simple schema. It is simply a list of objects that define each host via
a hostname to ssh to and a list of IPs to utilize for the communication.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
{"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
]
You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^^
In order to be able to launch the script on each host we need to be able to
connect via ssh. Moreover the input script and python binary need to be on each
host and on the same path. A good checklist to debug errors is the following:
* ``ssh hostname`` works without asking for password or host confirmation
* the python binary is available on all hosts at the same path. You can use
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
.. _mpi_specifics:
MPI Specifics
-------------
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
* The IPs in the hostfile are ignored
* The ssh connectivity requirement is stronger as every node needs to be able
to connect to every other node
* ``mpirun`` needs to be available on every node at the same path
Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
to choose a specific interface for the byte-transfer-layer of MPI we can call
``mlx.launch`` as follows:
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. _ring_specifics:
Ring Specifics
--------------
The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different to MPI:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.

View File

@ -13,12 +13,12 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.
MLX uses lazy evaluation because it has some nice features, some of which we
describe below.
describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lazy evaluation lets us record a compute graph without actually doing any
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.
@ -109,14 +109,14 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.

View File

@ -3,11 +3,7 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.
.. code-block:: python
@ -21,13 +17,11 @@ Let's convert an array to NumPy and back.
.. note::
Since NumPy does not support ``bfloat16`` arrays, you will need to convert
to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
buffer format string does not match the dtype V item size 0.``
Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
``np.array(a.astype(mx.float32))``.
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
By default, NumPy copies data to a new array. This can be prevented by creating
an array view:
By default, NumPy copies data to a new array. This can be prevented by creating an array view:
.. code-block:: python
@ -37,16 +31,10 @@ an array view:
a_view[0] = 1
print(a[0].item()) # 1
.. note::
A NumPy array view is a normal NumPy array, except that it does not own its memory.
This means writing to the view is reflected in the original array.
NumPy arrays with type ``float64`` will be default converted to MLX arrays
with type ``float32``.
A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.
While this is quite powerful to prevent copying arrays, it should be noted that
external changes to the memory of arrays cannot be reflected in gradients.
While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
Let's demonstrate this in an example:
@ -64,24 +52,22 @@ Let's demonstrate this in an example:
The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated. It's important to note that a similar issue arises
during array conversion and copying. For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
representing the gradient of the sum operation alone.
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
It's important to note that a similar issue arises during array conversion and copying.
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
PyTorch
-------
.. warning::
.. warning::
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.
PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
.. code-block:: python
@ -92,8 +78,7 @@ PyTorch supports the buffer protocol, but it requires an explicit
b = torch.tensor(memoryview(a))
c = mx.array(b.numpy())
Conversion from PyTorch tensors back to arrays must be done via intermediate
NumPy arrays with ``numpy()``.
Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
JAX
---
@ -111,8 +96,7 @@ JAX fully supports the buffer protocol.
TensorFlow
----------
TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
.. code-block:: python

View File

@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products.
Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
gradient with respect to the function's input.

View File

@ -8,33 +8,33 @@ Saving and Loading Arrays
MLX supports multiple array serialization formats.
.. list-table:: Serialization Formats
:widths: 20 8 25 25
:widths: 20 8 25 25
:header-rows: 1
* - Format
- Extension
* - Format
- Extension
- Function
- Notes
* - NumPy
- ``.npy``
- Notes
* - NumPy
- ``.npy``
- :func:`save`
- Single arrays only
* - NumPy archive
- ``.npz``
* - NumPy archive
- ``.npz``
- :func:`savez` and :func:`savez_compressed`
- Multiple arrays
- Multiple arrays
* - Safetensors
- ``.safetensors``
- ``.safetensors``
- :func:`save_safetensors`
- Multiple arrays
* - GGUF
- ``.gguf``
- Multiple arrays
* - GGUF
- ``.gguf``
- :func:`save_gguf`
- Multiple arrays
The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.
:func:`load` depends on the format.
Here's an example of saving a single array to a file:
@ -49,7 +49,7 @@ it will be added. You can load the array with:
.. code-block:: shell
>>> mx.load("array.npy")
>>> mx.load("array.npy", a)
array([1], dtype=float32)
Here's an example of saving several arrays to a single file:

View File

@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:
without needing to move them from one memory location to another. For example:
.. code-block:: python

View File

@ -1,22 +0,0 @@
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Comment the following two commands only the MLX C++ library is installed and
# set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)

View File

@ -1,26 +0,0 @@
## Build and Run
Install MLX with Python:
```bash
pip install mlx>=0.22
```
Build the C++ example:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
Run the C++ example:
```
./build/example
```
which should output:
```
array([2, 4, 6], dtype=int32)
```

View File

@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}

View File

@ -8,5 +8,3 @@ endfunction(build_example)
build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)

View File

@ -1,22 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
if (!mx::distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = mx::distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
mx::array x = mx::ones({10});
mx::array out = mx::distributed::all_sum(x, global_group);
std::cout << out << std::endl;
}

View File

@ -10,7 +10,7 @@
/**
* An example of linear regression with MLX.
*/
namespace mx = mlx::core;
using namespace mlx::core;
int main() {
int num_features = 100;
@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.01;
// True parameters
auto w_star = mx::random::normal({num_features});
auto w_star = random::normal({num_features});
// The input examples (design matrix)
auto X = mx::random::normal({num_examples, num_features});
auto X = random::normal({num_examples, num_features});
// Noisy labels
auto eps = 1e-2 * mx::random::normal({num_examples});
auto y = mx::matmul(X, w_star) + eps;
auto eps = 1e-2 * random::normal({num_examples});
auto y = matmul(X, w_star) + eps;
// Initialize random parameters
mx::array w = 1e-2 * mx::random::normal({num_features});
array w = 1e-2 * random::normal({num_features});
auto loss_fn = [&](mx::array w) {
auto yhat = mx::matmul(X, w);
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
auto loss_fn = [&](array w) {
auto yhat = matmul(X, w);
return (0.5f / num_examples) * sum(square(yhat - y));
};
auto grad_fn = mx::grad(loss_fn);
auto grad_fn = grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grads = grad_fn(w);
w = w - learning_rate * grads;
mx::eval(w);
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
<< ", Throughput " << throughput << " (it/s)." << std::endl;

View File

@ -10,7 +10,7 @@
/**
* An example of logistic regression with MLX.
*/
namespace mx = mlx::core;
using namespace mlx::core;
int main() {
int num_features = 100;
@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.1;
// True parameters
auto w_star = mx::random::normal({num_features});
auto w_star = random::normal({num_features});
// The input examples
auto X = mx::random::normal({num_examples, num_features});
auto X = random::normal({num_examples, num_features});
// Labels
auto y = mx::matmul(X, w_star) > 0;
auto y = matmul(X, w_star) > 0;
// Initialize random parameters
mx::array w = 1e-2 * mx::random::normal({num_features});
array w = 1e-2 * random::normal({num_features});
auto loss_fn = [&](mx::array w) {
auto logits = mx::matmul(X, w);
auto loss_fn = [&](array w) {
auto logits = matmul(X, w);
auto scale = (1.0f / num_examples);
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
};
auto grad_fn = mx::grad(loss_fn);
auto grad_fn = grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grads = grad_fn(w);
w = w - learning_rate * grads;
mx::eval(w);
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
<< throughput << " (it/s)." << std::endl;

View File

@ -1,31 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
mx::metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(mx::Device::gpu);
auto s3 = new_stream(mx::Device::gpu);
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
auto x = mx::add(a, a, s2);
auto y = mx::add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << mx::multiply(x, y) << std::endl;
mx::metal::stop_capture();
}

View File

@ -5,11 +5,11 @@
#include "mlx/mlx.h"
namespace mx = mlx::core;
using namespace mlx::core;
void array_basics() {
// Make a scalar array:
mx::array x(1.0);
array x(1.0);
// Get the value out of it:
auto s = x.item<float>();
@ -29,31 +29,31 @@ void array_basics() {
// The datatype should be float32:
auto dtype = x.dtype();
assert(dtype == mx::float32);
assert(dtype == float32);
// Specify the dtype when constructing the array:
x = mx::array(1, mx::int32);
assert(x.dtype() == mx::int32);
x = array(1, int32);
assert(x.dtype() == int32);
x.item<int>(); // OK
// x.item<float>(); // Undefined!
// Make a multidimensional array:
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
// mlx is row-major by default so the first row of this array
// is [1.0, 2.0] and the second row is [3.0, 4.0]
// Make an array of shape {2, 2} filled with ones:
auto y = mx::ones({2, 2});
auto y = ones({2, 2});
// Pointwise add x and y:
auto z = mx::add(x, y);
auto z = add(x, y);
// Same thing:
z = x + y;
// mlx is lazy by default. At this point `z` only
// has a shape and a type but no actual data:
assert(z.dtype() == mx::float32);
assert(z.dtype() == float32);
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);
@ -63,34 +63,34 @@ void array_basics() {
// and inputs. When `eval` is called on an array (or arrays), the array and
// all of its dependencies are recursively evaluated to produce the result.
// Once an array is evaluated, it has data and is detached from its inputs.
mx::eval(z);
eval(z);
// Of course the array can still be an input to other operations. You can
// even call eval on the array again, this will just be a no-op:
mx::eval(z); // no-op
// Of course the array can still be an input to other operations. You can even
// call eval on the array again, this will just be a no-op:
eval(z); // no-op
// Some functions or methods on arrays implicitly evaluate them. For example
// accessing a value in an array or printing the array implicitly evaluate it:
z = mx::ones({1});
z = ones({1});
z.item<float>(); // implicit evaluation
z = mx::ones({2, 2});
z = ones({2, 2});
std::cout << z << std::endl; // implicit evaluation
}
void automatic_differentiation() {
auto fn = [](mx::array x) { return mx::square(x); };
auto fn = [](array x) { return square(x); };
// Computing the derivative function of a function
auto grad_fn = mx::grad(fn);
auto grad_fn = grad(fn);
// Call grad_fn on the input to get the derivative
auto x = mx::array(1.5);
auto x = array(1.5);
auto dfdx = grad_fn(x);
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
// d2fdx2 is 2
auto df2dx2 = grad(grad(fn))(x);
// df2dx2 is 2
}
int main() {

Some files were not shown because too many files have changed in this diff Show More