Compare commits

..

5 Commits

Author        SHA1          Message         Date
Alex Barron   82a956c1d9    fix test        2024-12-06 10:26:54 -08:00
Alex Barron   769704653a    cpu fallback    2024-12-06 01:22:50 -08:00
Alex Barron   c89ddf62b4    add checks      2024-12-06 01:09:00 -08:00
Alex Barron   3507c104a5    add test        2024-12-06 00:45:01 -08:00
Alex Barron   12a4d89a7c    working qsdpa   2024-12-06 00:21:05 -08:00
753 changed files with 25890 additions and 87980 deletions

.circleci/config.yml (new file)

@@ -0,0 +1,413 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
- run:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
- run:
name: Build CPP only
command: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
build_release:
parameters:
python_version:
type: string
default: "3.9"
xcode_version:
type: string
default: "15.2.0"
build_env:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install --upgrade setuptools
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
command: |
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
build_linux_release:
parameters:
python_version:
type: string
default: "3.9"
extra_env:
type: string
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]


@@ -1,24 +0,0 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}


@@ -1,38 +0,0 @@
name: 'Build Documentation'
description: 'Build documentation'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-linux
- name: Install dependencies
shell: bash
run: |
sudo apt-get install -y doxygen
source .venv/bin/activate
pip install -r docs/requirements.txt
pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: bash
run: tar -cf artifact.tar -C docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error


@@ -1,40 +0,0 @@
name: 'Build Linux wheel'
description: 'Build Linux wheel'
inputs:
build-backend:
description: 'Build the backend mlx-cpu package'
type: boolean
required: false
default: false
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Generate package stubs
shell: bash
run: |
pip install -e ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
- name: Build Python package
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
bash python/scripts/repair_linux.sh ${{ inputs.arch }}
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}


@@ -1,41 +0,0 @@
name: 'Build and Test on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
runs:
using: "composite"
steps:
- name: Install Python package
id: python_build
shell: sh
env:
DEBUG: 1
CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
# Can not build tests when the built executables can not run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
fi
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build CPP only
shell: bash
run: |
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake --build build -j $(nproc)


@@ -1,34 +0,0 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
build-backend:
description: 'Build the backend mlx-metal package'
type: boolean
required: false
default: false
runs:
using: "composite"
steps:
- name: Build Python package
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
pip install build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w


@@ -1,88 +0,0 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
runs:
using: "composite"
steps:
- name: Install dependencies
env:
DEBUG: 1
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Install tests dependencies
shell: bash -l {0}
run: |
pip install numpy torch tensorflow unittest-xml-reporting
- name: Run Python tests
shell: bash -l {0}
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
shell: bash -l {0}
run: |
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext --inplace
python test.py
- name: Build CPP only
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
shell: bash -l {0}
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
shell: bash -l {0}
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit


@@ -1,87 +0,0 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
toolkit:
description: 'Which toolkit to install'
required: false
default: 'cpu'
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
max-size: 1GB
- name: Install common dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
- name: Setup Python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
pip install setuptools cmake nanobind==2.4.0
echo PATH=$PATH >> $GITHUB_ENV
# Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
- name: Install MPI
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Install CUDA toolkit
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
env:
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: |
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
- name: CUDA packages and driver report
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"


@@ -1,24 +0,0 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- uses: conda-incubator/setup-miniconda@v3
with:
miniconda-version: "latest"
python-version: ${{ inputs.python-version }}


@@ -1,69 +0,0 @@
name: 'Run Linux tests'
inputs:
has-gpu:
description: 'Run GPU tests'
required: false
default: false
runs:
using: "composite"
steps:
- name: Run MPI tests
shell: bash
run: |
echo "::group::MPI tests"
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
run: |
echo "::group::Distributed tests"
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::Python tests - CPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::Python tests - GPU"
python -m tests discover python/tests -v
echo "::endgroup::"
- name: Run CPP tests - CPU
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::CPP tests - CPU"
./build/tests/tests
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::CPP tests - GPU"
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
echo "::endgroup::"


@@ -1,6 +0,0 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"


@@ -1,27 +0,0 @@
#!/bin/bash
set -ex
# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
blas-devel \
lapack-devel \
openblas-devel \
make \
cmake \
clang \
git
dnf clean all
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++
mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd


@@ -1,108 +0,0 @@
name: Build and Test
on:
pull_request:
push:
branches:
- main
# For testing CI without starting a pull request:
- test/*
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
cuda_build_and_test:
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
mac_build_and_test:
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
macos-target: ["14.0", "15.0"]
runs-on: [self-hosted, macos]
env:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh


@@ -1,28 +0,0 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4


@@ -1,96 +0,0 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12", "3.13", "3.14"]
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7

.github/workflows/pull_request.yml (new file)

@@ -0,0 +1,20 @@
on:
pull_request:
branches:
- main
jobs:
check_lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files


@@ -1,244 +0,0 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
inputs:
dev_release:
description: "Do a dev release or regular release"
required: true
default: "false"
permissions:
contents: read
jobs:
setup:
runs-on: ubuntu-latest
steps:
- name: Set publishing variables
run: echo "Publishing setup complete"
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_cuda_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
pattern: mlx-cpu-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/

.gitignore

@@ -36,7 +36,6 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
-uv.lock
 # vim
 *.swp
@@ -77,9 +76,6 @@ build/
 *.out
 *.app
-# Debug symbols
-*.pdb
 # VSCode
 .vscode/
 .DS_Store


@@ -1,22 +1,15 @@
 repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v6.0.0
-    hooks:
-      - id: check-yaml
-      # - id: end-of-file-fixer
-      # - id: trailing-whitespace
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v19.1.7
+    rev: v18.1.8
     hooks:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 25.1.0
+    rev: 24.8.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 6.0.0
+    rev: 5.13.2
     hooks:
       - id: isort
         args:


@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
@@ -19,17 +19,11 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>
-# Organizations
-MLX has received contributions from the following companies:
-- NVIDIA Corporation & Affiliates
 # Third-Party Software
 MLX leverages several third-party software, listed here together with


@@ -1,24 +1,6 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.24)
if(NOT MLX_VERSION) project(mlx LANGUAGES C CXX)
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
set(_major ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
set(_minor ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
set(_patch ${CMAKE_MATCH_1})
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
${MLX_VERSION})
endif()
project(
mlx
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})
# ----------------------------- Setup ----------------------------- # ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -26,7 +8,6 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER) set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration ----------------------------- # ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON) option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -35,18 +16,19 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON) option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON) option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF) option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.21.0)
endif()
# --------------------- Processor tests ------------------------- # --------------------- Processor tests -------------------------
message( message(
STATUS STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}" "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
@@ -67,18 +49,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.") message(WARNING "Building for x86_64 arch is not officially supported.")
endif() endif()
endif() endif()
else() else()
set(MLX_BUILD_METAL OFF) set(MLX_BUILD_METAL OFF)
endif() message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
endif() endif()
# ----------------------------- Lib ----------------------------- # ----------------------------- Lib -----------------------------
@@ -89,27 +63,19 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx) add_library(mlx)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal) set(METAL_LIB "-framework Metal")
find_library(FOUNDATION_LIB Foundation) set(FOUNDATION_LIB "-framework Foundation")
find_library(QUARTZ_LIB QuartzCore) set(QUARTZ_LIB "-framework QuartzCore")
if(METAL_LIB)
message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif() endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG) add_compile_definitions(MLX_METAL_DEBUG)
endif() endif()
@@ -117,8 +83,7 @@ if(MLX_BUILD_METAL)
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process( execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0) if(${MACOS_SDK_VERSION} LESS 14.0)
message( message(
@@ -128,12 +93,10 @@ if(MLX_BUILD_METAL)
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}") message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip) https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS >= 14.0")
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif() endif()
execute_process( execute_process(
@@ -142,6 +105,7 @@ if(MLX_BUILD_METAL)
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL}) FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp) FetchContent_MakeAvailable(metal_cpp)
target_include_directories( target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}> mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
@@ -149,62 +113,16 @@ if(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif() endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
set(MLX_BUILD_GGUF OFF)
# There is no prebuilt OpenBLAS distribution for MSVC.
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
endif()
# Windows implementation of dlfcn.h APIs.
FetchContent_Declare(
dlfcn-win32
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.1
EXCLUDE_FROM_ALL)
block()
set(BUILD_SHARED_LIBS OFF)
FetchContent_MakeAvailable(dlfcn-win32)
endblock()
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
target_link_libraries(mlx PRIVATE dl)
endif()
if(MLX_BUILD_CPU) if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate) find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY) if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON) set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY}) target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(MLX_USE_ACCELERATE)
add_compile_definitions(ACCELERATE_NEW_LAPACK) add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
# Download and build OpenBLAS from source code.
FetchContent_Declare(
openblas
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
GIT_TAG v0.3.28
EXCLUDE_FROM_ALL)
set(BUILD_STATIC_LIBS ON) # link statically
set(NOFORTRAN ON) # msvc has no fortran compiler
FetchContent_MakeAvailable(openblas)
target_link_libraries(mlx PRIVATE openblas)
target_include_directories(
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
else() else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE}) if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for # The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead. # openblas instead.
@@ -222,7 +140,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES}) target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old # List blas after lapack otherwise we may accidentally incldue an old
# version of lapack.h from the include dirs of blas. # version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED) find_package(BLAS REQUIRED)
@@ -235,19 +153,36 @@ if(MLX_BUILD_CPU)
message(STATUS "Blas lib " ${BLAS_LIBRARIES}) message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES}) target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
if(WIN32)
find_package(dlfcn-win32 REQUIRED)
message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
endif()
endif() endif()
else() else()
set(MLX_BUILD_ACCELERATE OFF) set(MLX_BUILD_ACCELERATE OFF)
endif() endif()
message(STATUS "Downloading json") find_package(MPI)
FetchContent_Declare( if(MPI_FOUND)
json execute_process(
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) COMMAND zsh "-c" "mpirun --version"
FetchContent_MakeAvailable(json) OUTPUT_VARIABLE MPI_VERSION
target_include_directories( ERROR_QUIET)
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>) if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -255,19 +190,12 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}> mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>) $<INSTALL_INTERFACE:include>)
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare( FetchContent_Declare(
fmt fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1 GIT_TAG 10.2.1
EXCLUDE_FROM_ALL) EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt) FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>) target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS) if(MLX_BUILD_PYTHON_BINDINGS)
@@ -279,7 +207,8 @@ if(MLX_BUILD_PYTHON_BINDINGS)
execute_process( execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE nanobind_ROOT) OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED) find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif() endif()


@@ -17,11 +17,11 @@ possible.
 You can also run the formatters manually as follows:
-```shell
+```
 clang-format -i file.cpp
 ```
-```shell
+```
 black file.py
 ```


@@ -1,6 +1,4 @@
 include CMakeLists.txt
-include mlx.pc.in
 recursive-include mlx/ *
-include cmake/*
 include python/src/*
 include python/mlx/py.typed # support type hinting as in PEP-561


@@ -68,23 +68,18 @@ in the documentation.
 ## Installation
-MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
-macOS, run:
-```bash
+MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
+**With `pip`**:
+```
 pip install mlx
 ```
-To install the CUDA backend on Linux, run:
-```bash
-pip install mlx[cuda]
-```
-To install a CPU-only Linux package, run:
-```bash
-pip install mlx[cpu]
-```
+**With `conda`**:
+```
+conda install -c conda-forge mlx
+```
 Checkout the
@@ -110,7 +105,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
 MLX useful in your research and wish to cite it, please use the following
 BibTex entry:
-```text
+```
 @software{mlx2023,
   author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
   title = {{MLX}: Efficient and flexible machine learning on Apple silicon},


@@ -5,35 +5,35 @@
#include "mlx/mlx.h" #include "mlx/mlx.h"
#include "time_utils.h" #include "time_utils.h"
namespace mx = mlx::core; using namespace mlx::core;
void time_value_and_grad() { void time_value_and_grad() {
auto x = mx::ones({200, 1000}); auto x = ones({200, 1000});
mx::eval(x); eval(x);
auto fn = [](mx::array x) { auto fn = [](array x) {
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x)); x = log(exp(x));
} }
return mx::sum(x); return sum(x);
}; };
auto grad_fn = mx::grad(fn); auto grad_fn = grad(fn);
auto independent_value_and_grad = [&]() { auto independent_value_and_grad = [&]() {
auto value = fn(x); auto value = fn(x);
auto dfdx = grad_fn(x); auto dfdx = grad_fn(x);
return std::vector<mx::array>{value, dfdx}; return std::vector<array>{value, dfdx};
}; };
TIME(independent_value_and_grad); TIME(independent_value_and_grad);
auto value_and_grad_fn = mx::value_and_grad(fn); auto value_and_grad_fn = value_and_grad(fn);
auto combined_value_and_grad = [&]() { auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x); auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<mx::array>{value, dfdx}; return std::vector<array>{value, dfdx};
}; };
TIME(combined_value_and_grad); TIME(combined_value_and_grad);
} }
int main() { int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl; std::cout << "Benchmarks for " << default_device() << std::endl;
time_value_and_grad(); time_value_and_grad();
} }


@@ -4,21 +4,21 @@
#include "mlx/mlx.h" #include "mlx/mlx.h"
#include "time_utils.h" #include "time_utils.h"
namespace mx = mlx::core; using namespace mlx::core;
void time_add_op() { void time_add_op() {
std::vector<int> sizes(1, 1); std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) { for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back()); sizes.push_back(10 * sizes.back());
} }
set_default_device(mx::Device::cpu); set_default_device(Device::cpu);
for (auto size : sizes) { for (auto size : sizes) {
auto a = mx::random::uniform({size}); auto a = random::uniform({size});
auto b = mx::random::uniform({size}); auto b = random::uniform({size});
mx::eval(a, b); eval(a, b);
std::cout << "Size " << size << std::endl; std::cout << "Size " << size << std::endl;
TIMEM("cpu", mx::add, a, b, mx::Device::cpu); TIMEM("cpu", add, a, b, Device::cpu);
TIMEM("gpu", mx::add, a, b, mx::Device::gpu); TIMEM("gpu", add, a, b, Device::gpu);
} }
} }


@@ -1,111 +1,110 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "mlx/mlx.h" #include "mlx/mlx.h"
#include "time_utils.h" #include "time_utils.h"
namespace mx = mlx::core; using namespace mlx::core;
void time_irregular_binary_ops_1D() { void time_irregular_binary_ops_1D() {
auto device = mx::default_device(); auto device = default_device();
int size = 1000000; int size = 1000000;
int step = 2; int step = 2;
auto a = mx::random::uniform({size}); auto a = random::uniform({size});
auto b = mx::random::uniform({size}); auto b = random::uniform({size});
mx::eval(a, b); eval(a, b);
a = slice(a, {0}, {size}, {step}); a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step}); b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", mx::add, a, b, device); TIMEM("1D strided", add, a, b, device);
} }
void time_irregular_binary_ops_2D() { void time_irregular_binary_ops_2D() {
auto device = mx::default_device(); auto device = default_device();
int size = 2048; int size = 2048;
auto a = mx::random::uniform({size, size}); auto a = random::uniform({size, size});
auto b = mx::random::uniform({size, size}); auto b = random::uniform({size, size});
mx::eval(a, b); eval(a, b);
TIMEM("2D regular", mx::add, a, b, device); TIMEM("2D regular", add, a, b, device);
b = mx::transpose(b); b = transpose(b);
mx::eval(b); eval(b);
TIMEM("2D mx::transpose", mx::add, a, b, device); TIMEM("2D transpose", add, a, b, device);
b = mx::random::uniform({size}); b = random::uniform({size});
mx::eval(b); eval(b);
TIMEM("2D broadcast dim 0", mx::add, a, b, device); TIMEM("2D broadcast dim 0", add, a, b, device);
b = mx::reshape(b, {size, 1}); b = reshape(b, {size, 1});
mx::eval(b); eval(b);
TIMEM("2D broadcast dim 1", mx::add, a, b, device); TIMEM("2D broadcast dim 1", add, a, b, device);
} }
void time_irregular_binary_ops_3D() { void time_irregular_binary_ops_3D() {
auto device = mx::default_device(); auto device = default_device();
int d0 = 32; int d0 = 32;
int d1 = 512; int d1 = 512;
int d2 = 512; int d2 = 512;
auto a = mx::random::uniform({d0, d1, d2}); auto a = random::uniform({d0, d1, d2});
auto b = mx::random::uniform({d0, d1, d2}); auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", mx::add, a, b, device); TIMEM("3D regular", add, a, b, device);
b = mx::transpose(b, {0, 2, 1}); b = transpose(b, {0, 2, 1});
TIMEM("3D mx::transpose", mx::add, a, b, device); TIMEM("3D transpose", add, a, b, device);
b = mx::random::uniform({d1, d2}); b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", mx::add, a, b, device); TIMEM("3D broadcast dim 0", add, a, b, device);
b = mx::random::uniform({d0, 1, d2}); b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", mx::add, a, b, device); TIMEM("3D broadcast dim 1", add, a, b, device);
b = mx::random::uniform({d0, d1, 1}); b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", mx::add, a, b, device); TIMEM("3D broadcast dim 2", add, a, b, device);
b = mx::random::uniform({d2}); b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device); TIMEM("3D broadcast dims 0, 1", add, a, b, device);
b = mx::random::uniform({d1, 1}); b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device); TIMEM("3D broadcast dims 0, 2", add, a, b, device);
b = mx::random::uniform({d0, 1, 1}); b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device); TIMEM("3D broadcast dims 1, 2", add, a, b, device);
} }
void time_irregular_binary_ops_4D() { void time_irregular_binary_ops_4D() {
auto device = mx::default_device(); auto device = default_device();
mx::Shape shape = {8, 8, 512, 512}; std::vector<int> shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape); auto a = random::uniform(shape);
auto b = mx::random::uniform(shape); auto b = random::uniform(shape);
TIMEM("4D regular", mx::add, a, b, device); TIMEM("4D regular", add, a, b, device);
b = mx::transpose(b, {0, 1, 3, 2}); b = transpose(b, {0, 1, 3, 2});
TIMEM("4D mx::transpose", mx::add, a, b, device); TIMEM("4D transpose", add, a, b, device);
std::string om = "4D broadcast dims "; std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) { for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1; shape[i] = 1;
b = mx::random::uniform(shape); b = random::uniform(shape);
std::ostringstream msg; std::ostringstream msg;
msg << om << i; msg << om << i;
TIMEM(msg.str(), mx::add, a, b, device); TIMEM(msg.str(), add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) { for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1; shape[j] = 1;
std::ostringstream msg; std::ostringstream msg;
msg << om << i << ", " << j; msg << om << i << ", " << j;
b = mx::random::uniform(shape); b = random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device); TIMEM(msg.str(), add, a, b, device);
shape[j] = a.shape(j); shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) { for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1; shape[k] = 1;
std::ostringstream msg; std::ostringstream msg;
msg << om << i << ", " << j << ", " << k; msg << om << i << ", " << j << ", " << k;
b = mx::random::uniform(shape); b = random::uniform(shape);
TIMEM(msg.str(), mx::add, a, b, device); TIMEM(msg.str(), add, a, b, device);
shape[k] = a.shape(k); shape[k] = a.shape(k);
} }
} }
@@ -114,83 +113,83 @@ void time_irregular_binary_ops_4D() {
} }
void time_irregular_reshape() { void time_irregular_reshape() {
auto device = mx::default_device(); auto device = default_device();
mx::Shape shape; std::vector<int> shape;
auto reshape_fn = [&shape, device](const mx::array& a) { auto reshape_fn = [&shape, device](const array& a) {
return mx::reshape(a, shape, device); return reshape(a, shape, device);
}; };
int size = 64; int size = 64;
int d = 2 * size; int d = 2 * size;
auto a = mx::random::uniform({d, d, d}); auto a = random::uniform({d, d, d});
shape = {8 * size, size, size}; shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a); TIMEM("3D contiguous", reshape_fn, a);
a = mx::transpose(a); a = transpose(a);
shape = {8 * size, size, size}; shape = {8 * size, size, size};
TIMEM("3D mx::transpose", reshape_fn, a); TIMEM("3D transpose", reshape_fn, a);
a = mx::transpose(a, {1, 2, 0}); a = transpose(a, {1, 2, 0});
shape = {8 * size, size, size}; shape = {8 * size, size, size};
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a); TIMEM("3D transpose dims 1 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d}); a = broadcast_to(random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a); TIMEM("3D broadcast dim 0", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d}); a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a); TIMEM("3D broadcast dim 1", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d}); a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a); TIMEM("3D broadcast dim 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d}); a = broadcast_to(random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a); TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d}); a = broadcast_to(random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a); TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d}); a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a); TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d}); a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a); TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
} }
void time_irregular_astype_1D() { void time_irregular_astype_1D() {
auto device = mx::default_device(); auto device = default_device();
int size = 1000000; int size = 1000000;
int step = 2; int step = 2;
auto a = mx::random::uniform({size}); auto a = random::uniform({size});
a = slice(a, {0}, {size}, {step}); a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", mx::astype, a, mx::int32, device); TIMEM("1D strided", astype, a, int32, device);
} }
void time_irregular_astype_2D() { void time_irregular_astype_2D() {
auto device = mx::default_device(); auto device = default_device();
int size = 2048; int size = 2048;
mx::Shape shape = {size, size}; std::vector<int> shape = {size, size};
auto a = mx::random::uniform(shape); auto a = random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device); TIMEM("2D regular", astype, a, int32, device);
a = mx::transpose(a); a = transpose(a);
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device); TIMEM("2D transpose", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size}), shape); a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device); TIMEM("2D broadcast dim 0", astype, a, int32, device);
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape); a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device); TIMEM("2D broadcast dim 1", astype, a, int32, device);
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
if (argc > 1) { if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu"); bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu); set_default_device(use_gpu ? Device::gpu : Device::cpu);
} }
std::cout << "Benchmarks for " << mx::default_device() << std::endl; std::cout << "Benchmarks for " << default_device() << std::endl;
time_irregular_binary_ops_1D(); time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D(); time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D(); time_irregular_binary_ops_3D();

View File

@@ -3,20 +3,20 @@
#include "mlx/mlx.h" #include "mlx/mlx.h"
#include "time_utils.h" #include "time_utils.h"
namespace mx = mlx::core; using namespace mlx::core;
void time_creation_ops() { void time_creation_ops() {
int M = 2000; int M = 2000;
int N = 500; int N = 500;
auto shape = {M, N}; auto shape = {M, N};
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); }; auto full_fp32 = [&]() { return full(shape, 3.3f); };
TIME(full_fp32); TIME(full_fp32);
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); }; auto zeros_fp32 = [&]() { return zeros(shape, float32); };
TIME(zeros_fp32); TIME(zeros_fp32);
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); }; auto ones_fp32 = [&]() { return ones(shape, float32); };
TIME(ones_fp32); TIME(ones_fp32);
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); }; auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32); TIME(arange_fp32);
} }
@@ -24,212 +24,194 @@ void time_type_conversions() {
int M = 2000; int M = 2000;
int N = 500; int N = 500;
auto shape = {M, N}; auto shape = {M, N};
auto device = mx::default_device(); auto device = default_device();
auto a = mx::zeros(shape, mx::float32); auto a = zeros(shape, float32);
mx::eval(a); eval(a);
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device); TIMEM("float32 to int32", astype, a, int32, device);
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device); TIMEM("float32 to uint32", astype, a, uint32, device);
a = mx::zeros(shape, mx::int32); a = zeros(shape, int32);
mx::eval(a); eval(a);
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device); TIMEM("int32 to float32", astype, a, float32, device);
a = mx::zeros(shape, mx::bool_); a = zeros(shape, bool_);
mx::eval(a); eval(a);
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device); TIMEM("bool to float32", astype, a, float32, device);
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device); TIMEM("bool to int32", astype, a, int32, device);
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device); TIMEM("bool to uint32", astype, a, uint32, device);
} }
void time_random_generation() { void time_random_generation() {
int M = 2000; int M = 2000;
int N = 500; int N = 500;
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); }; auto uniform = [&]() { return random::uniform({M, N}, float32); };
TIME(uniform); TIME(uniform);
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); }; auto normal = [&]() { return random::normal({M, N}, float32); };
TIME(normal); TIME(normal);
} }
void time_unary_ops() { void time_unary_ops() {
int M = 2000; int M = 2000;
int N = 500; int N = 500;
auto device = mx::default_device(); auto device = default_device();
auto a = mx::random::normal({M, N}); auto a = random::normal({M, N});
mx::eval(a); eval(a);
TIME(mlx::core::abs, a, device); TIME(mlx::core::abs, a, device);
TIME(mx::negative, a, device); TIME(negative, a, device);
TIME(mx::sign, a, device); TIME(sign, a, device);
TIME(mx::square, a, device); TIME(square, a, device);
TIME(mlx::core::sqrt, a, device); TIME(mlx::core::sqrt, a, device);
TIME(mx::rsqrt, a, device); TIME(rsqrt, a, device);
TIME(mlx::core::exp, a, device); TIME(mlx::core::exp, a, device);
a = mx::random::uniform({M, N}); a = random::uniform({M, N});
TIME(mlx::core::log, a, device); TIME(mlx::core::log, a, device);
} }
void time_binary_ops() { void time_binary_ops() {
int M = 1000, N = 100, K = 10; int M = 1000, N = 100, K = 10;
auto condition = mx::random::randint(0, 2, {M, N, K}); auto condition = random::randint(0, 2, {M, N, K});
auto a = mx::random::uniform({M, N, K}); auto a = random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K}); auto b = random::uniform({M, N, K});
auto device = mx::default_device(); auto device = default_device();
mx::eval(a, b); eval(a, b);
TIME(mx::add, a, b, device); TIME(add, a, b, device);
TIME(mx::subtract, a, b, device); TIME(subtract, a, b, device);
TIME(mx::multiply, a, b, device); TIME(multiply, a, b, device);
TIME(mx::divide, a, b, device); TIME(divide, a, b, device);
TIME(mx::maximum, a, b, device); TIME(maximum, a, b, device);
TIME(mx::minimum, a, b, device); TIME(minimum, a, b, device);
TIME(mx::where, condition, a, b, device); TIME(where, condition, a, b, device);
condition = mx::array({true}); condition = array({true});
b = mx::random::uniform({1}); b = random::uniform({1});
mx::eval(b); eval(b);
TIMEM("scalar", mx::add, a, b, device); TIMEM("scalar", add, a, b, device);
TIMEM("vector-scalar", mx::subtract, a, b, device); TIMEM("vector-scalar", subtract, a, b, device);
TIMEM("scalar-vector", mx::subtract, b, a, device); TIMEM("scalar-vector", subtract, b, a, device);
TIMEM("scalar", mx::multiply, a, b, device); TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", mx::divide, a, b, device); TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", mx::divide, b, a, device); TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", mx::where, condition, a, b, device); TIMEM("scalar-vector", where, condition, a, b, device);
condition = mx::broadcast_to(mx::array({true}), {1000, 100}); condition = broadcast_to(array({true}), {1000, 100});
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100}); a = broadcast_to(random::uniform({1}), {1000, 100});
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100}); b = broadcast_to(random::uniform({1}), {1000, 100});
mx::eval(a, b); eval(a, b);
TIMEM("scalar-scalar broadcast", mx::add, a, b, device); TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device); TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device); TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device); TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device); TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
} }
void time_strided_ops() { void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50; int M = 50, N = 50, O = 50, P = 50;
auto a = mx::random::uniform({M, N, O, P}); auto a = random::uniform({M, N, O, P});
auto b = mx::random::uniform({M, N, O, P}); auto b = random::uniform({M, N, O, P});
auto device = mx::default_device(); auto device = default_device();
mx::eval(a, b); eval(a, b);
TIMEM("non-strided", mx::add, a, b, device); TIMEM("non-strided", add, a, b, device);
a = mx::transpose(a, {1, 0, 2, 3}); a = transpose(a, {1, 0, 2, 3});
b = mx::transpose(b, {3, 2, 0, 1}); b = transpose(b, {3, 2, 0, 1});
mx::eval(a, b); eval(a, b);
TIMEM("strided", mx::add, a, b, device); TIMEM("strided", add, a, b, device);
} }
void time_comparisons() { void time_comparisons() {
int M = 1000, N = 100, K = 10; int M = 1000, N = 100, K = 10;
auto a = mx::random::uniform({M, N, K}); auto a = random::uniform({M, N, K});
auto b = mx::random::uniform({M, N, K}); auto b = random::uniform({M, N, K});
auto device = mx::default_device(); auto device = default_device();
mx::eval(a, b); eval(a, b);
TIME(mx::equal, a, b, device); TIME(equal, a, b, device);
TIME(mx::greater, a, b, device); TIME(greater, a, b, device);
TIME(mx::greater_equal, a, b, device); TIME(greater_equal, a, b, device);
TIME(mx::less, a, b, device); TIME(less, a, b, device);
TIME(mx::less_equal, a, b, device); TIME(less_equal, a, b, device);
} }
void time_matvec() { void time_matvec() {
int M = 2000, N = 200; int M = 2000, N = 200;
auto a = mx::random::uniform({M, N}); auto a = random::uniform({M, N});
auto b = mx::random::uniform({N}); auto b = random::uniform({N});
auto c = mx::random::uniform({M}); auto c = random::uniform({M});
mx::eval(a, b, c); eval(a, b, c);
auto matvec = [&]() { return mx::matmul(a, b); }; auto matvec = [&]() { return matmul(a, b); };
TIME(matvec); TIME(matvec);
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); }; auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
TIME(matvec_transpose); TIME(matvec_transpose);
} }
void time_matmul() { void time_matmul() {
int M = 1000, N = 1000, K = 1000; int M = 1000, N = 1000, K = 1000;
auto a = mx::random::uniform({M, K}); auto a = random::uniform({M, K});
auto b = mx::random::uniform({K, N}); auto b = random::uniform({K, N});
auto device = mx::default_device(); auto device = default_device();
mx::eval(a, b); eval(a, b);
TIME(mx::matmul, a, b, device); TIME(matmul, a, b, device);
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); }; auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
TIME(transpose_matmul); TIME(transpose_matmul);
} }
void time_reductions() { void time_reductions() {
auto a = mx::random::normal({10000, 1000}); auto a = random::normal({10000, 1000});
mx::eval(a); eval(a);
auto sum_all = [&a]() { return mx::sum(a, false); }; auto sum_all = [&a]() { return sum(a, false); };
TIME(sum_all); TIME(sum_all);
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); }; auto sum_along_0 = [&a]() { return sum(a, 0, false); };
TIME(sum_along_0); TIME(sum_along_0);
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); }; auto sum_along_1 = [&a]() { return sum(a, 1, false); };
TIME(sum_along_1); TIME(sum_along_1);
auto prod_all = [&a]() { return mx::prod(a, false); }; auto prod_all = [&a]() { return prod(a, false); };
TIME(prod_all); TIME(prod_all);
auto all_true = [&a]() { return mx::all(a, false); }; auto all_true = [&a]() { return all(a, false); };
TIME(all_true); TIME(all_true);
auto all_along_0 = [&a]() { return mx::all(a, 0, false); }; auto all_along_0 = [&a]() { return all(a, 0, false); };
TIME(all_along_0); TIME(all_along_0);
auto all_along_1 = [&a]() { return mx::all(a, 1, false); }; auto all_along_1 = [&a]() { return all(a, 1, false); };
TIME(all_along_1); TIME(all_along_1);
auto any_true = [&a]() { return mx::any(a, false); }; auto any_true = [&a]() { return any(a, false); };
TIME(any_true); TIME(any_true);
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); }; auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
TIME(argmin_along_0); TIME(argmin_along_0);
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); }; auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
TIME(argmin_along_1); TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
} }
void time_gather_scatter() { void time_gather_scatter() {
auto a = mx::random::normal({1000, 768}); auto a = random::normal({1000, 768});
mx::eval(a); eval(a);
auto indices = mx::random::randint(0, 1000, {256}); auto indices = random::randint(0, 1000, {256});
mx::eval(indices); eval(indices);
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); }; auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
TIME(embedding_lookup); TIME(embedding_lookup);
indices = mx::random::randint(0, 768 * 1000, {256 * 768}); indices = random::randint(0, 768 * 1000, {256 * 768});
mx::eval(indices); eval(indices);
auto single_element_lookup = [&a, &indices]() { auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
return mx::take(a, indices);
};
TIME(single_element_lookup); TIME(single_element_lookup);
indices = mx::random::randint(0, 1000, {256}); indices = random::randint(0, 1000, {256});
auto updates = mx::random::normal({256, 1, 768}); auto updates = random::normal({256, 1, 768});
mx::eval(indices, updates); eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() { auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0); return scatter(a, indices, updates, 0);
@@ -241,10 +223,10 @@ void time_gather_scatter() {
}; };
TIME(embedding_add); TIME(embedding_add);
a = mx::reshape(a, {-1}); a = reshape(a, {-1});
indices = mx::random::randint(0, 768 * 1000, {768 * 256}); indices = random::randint(0, 768 * 1000, {768 * 256});
updates = mx::random::normal({256 * 768, 1}); updates = random::normal({256 * 768, 1});
mx::eval(a, indices, updates); eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() { auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0); return scatter(a, indices, updates, 0);
@@ -258,21 +240,21 @@ void time_gather_scatter() {
} }
void time_divmod() { void time_divmod() {
auto a = mx::random::normal({1000}); auto a = random::normal({1000});
auto b = mx::random::normal({1000}); auto b = random::normal({1000});
mx::eval({a, b}); eval({a, b});
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); }; auto divmod_fused = [&a, &b]() { return divmod(a, b); };
TIME(divmod_fused); TIME(divmod_fused);
auto divmod_separate = [&a, &b]() { auto divmod_separate = [&a, &b]() {
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)}; return std::vector<array>{floor_divide(a, b), remainder(a, b)};
}; };
TIME(divmod_separate); TIME(divmod_separate);
} }
int main() { int main() {
std::cout << "Benchmarks for " << mx::default_device() << std::endl; std::cout << "Benchmarks for " << default_device() << std::endl;
time_creation_ops(); time_creation_ops();
time_type_conversions(); time_type_conversions();
time_unary_ops(); time_unary_ops();

View File

@@ -142,7 +142,9 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1) t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b) c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype) c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
atol = 1e-5 if np_dtype == np.float32 else 1e-4 atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -161,7 +163,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks") parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16", "complex64") dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn") transposes = ("nn", "nt", "tn")
shapes = ( shapes = (
(16, 234, 768, 3072), (16, 234, 768, 3072),
@@ -185,7 +187,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0 diff = gflops_mx / gflops_pt - 1.0
print( print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%" f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
) )
if gflops_pt >= 2.0 * gflops_mx: if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^") print("ATTENTION ^^^^^^^")

View File

@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse
import os import os
import subprocess import subprocess
import time import time
@@ -195,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True): for transpose in (False, True):
for dtype in ("float32", "float16", "complex64"): for dtype in ("float32", "float16"):
fig, axs = plt.subplots( fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained" len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
) )
@@ -214,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}") fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig( fig.savefig(
os.path.join( os.path.join(
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf" results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
) )
) )
plt.close(fig) plt.close(fig)

View File

@@ -5,7 +5,6 @@ import os
import time import time
import torch import torch
import torch.cuda
import torch.mps import torch.mps
@@ -45,10 +44,8 @@ def bench(f, *args):
def sync_if_needed(x): def sync_if_needed(x):
if x.device == torch.device("mps"): if x.device != torch.device("cpu"):
torch.mps.synchronize() torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()
@torch.no_grad() @torch.no_grad()
@@ -102,14 +99,6 @@ def reduction(op, axis, x):
sync_if_needed(x) sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad() @torch.no_grad()
def softmax(axis, x): def softmax(axis, x):
ys = [] ys = []
@@ -351,11 +340,7 @@ if __name__ == "__main__":
args.axis.pop(0) args.axis.pop(0)
torch.set_num_threads(1) torch.set_num_threads(1)
device = "mps" device = "cpu" if args.cpu else "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
types = args.dtype types = args.dtype
if not types: if not types:
@@ -475,8 +460,5 @@ if __name__ == "__main__":
elif args.benchmark == "selu": elif args.benchmark == "selu":
print(bench(selu, x)) print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else: else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.") raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -1,107 +0,0 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import argparse import argparse
from time import time
import mlx.core as mx import mlx.core as mx
import torch import torch

View File

@@ -1,74 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()
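
The gather_sort/scatter_unsort pair above relies on a small permutation identity: sorting the routing indices groups rows that use the same expert, and the argsort of the sort order restores the original layout. A toy sketch with hypothetical indices:

import mlx.core as mx

idx = mx.array([2, 0, 2, 1])        # hypothetical expert assignments
order = mx.argsort(idx)             # groups rows with the same expert together
inv_order = mx.argsort(order)       # permutation that undoes the sort
x = mx.arange(4)
print(mx.all(x[order][inv_order] == x))  # True: scatter_unsort reverses gather_sort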

View File

@@ -1,84 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()
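
In the quantized variant, each weight tensor is carried around as the (packed weights, scales, biases) triple returned by mx.quantize and unpacked into the matmul calls. A quick round-trip sketch with illustrative sizes:

import mlx.core as mx

w = mx.random.normal((128, 256))
w_q = mx.quantize(w, bits=4)        # (packed weights, scales, biases)
w_hat = mx.dequantize(*w_q, bits=4)
print(mx.abs(w - w_hat).max())      # small quantization error, not an exact round trip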

View File

@@ -1,7 +1,5 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
from functools import partial
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from time_utils import time_fn from time_utils import time_fn
@@ -12,71 +10,32 @@ def layer_norm(x, w, b, eps):
x = x.astype(mx.float32) x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True) mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True) v = mx.var(x, -1, keepdims=True)
y = (x - mu) * mx.rsqrt(v + eps) return (x - mu) * mx.rsqrt(v + eps) * w + b
if w is not None:
y = y * w
if b is not None:
y = y + b
return y
def time_layer_norm(N, dt): def time_layer_norm():
L = 1024
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum() f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum() f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2)) g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2)) g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, L, N)).astype(dt) x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(N,)).astype(dt) w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(N,)).astype(dt) b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, L, N)).astype(dt) y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y) mx.eval(x, w, b, y)
def layer_norm_loop(f, x, w, b): def layer_norm_loop(g, x, w, b):
for _ in range(32):
x = f(x, w, b)
return x
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
def layer_norm_grad_loop(g, x, w, b):
gx, gw, gb = x, w, b gx, gw, gb = x, w, b
for _ in range(32): for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y) gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb return gx, gw, gb
time_fn(layer_norm_grad_loop, g1, x, w, b) time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_grad_loop, g2, x, w, b) time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b) time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b) time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
def layer_norm_grad_x_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(layer_norm_grad_x_loop, g1, x)
time_fn(layer_norm_grad_x_loop, g2, x)
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
if __name__ == "__main__": if __name__ == "__main__":
for dt in [mx.float32, mx.float16, mx.bfloat16]: time_layer_norm()
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
print(dt, n)
time_layer_norm(n, dt)
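
The updated reference layer_norm treats the weight and bias as optional, mirroring mx.fast.layer_norm, which also accepts None for either argument. A minimal sketch, assuming small illustrative shapes:

import mlx.core as mx

x = mx.random.uniform(shape=(2, 8, 64)).astype(mx.float16)
w = mx.random.uniform(shape=(64,)).astype(mx.float16)
y_affine = mx.fast.layer_norm(x, w, None, 1e-5)    # scale only, no bias
y_plain = mx.fast.layer_norm(x, None, None, 1e-5)  # no affine parameters at all
mx.eval(y_affine, y_plain)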

View File

@@ -1,212 +0,0 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter
RESULTS_DIR = "./results"
if not os.path.isdir(RESULTS_DIR):
os.mkdir(RESULTS_DIR)
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
TORCH_DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else ("cuda" if torch.cuda.is_available() else "cpu")
)
N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")
def _power_of_two_formatter(value, _position):
if value <= 0:
return ""
exponent = int(round(math.log2(value)))
if abs(value - (1 << exponent)) / value > 1e-6:
return f"{value:g}"
return f"$2^{{{exponent}}}$"
def torch_sync():
if TORCH_DEVICE.type == "cuda":
torch.cuda.synchronize()
elif TORCH_DEVICE.type == "mps":
torch.mps.synchronize()
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
outs = []
for _ in range(N_ITER_FUNC):
out = copy(self_arr)
out[mask_arr] = src_arr
outs.append(out)
mx.eval(outs)
return outs
@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
outs = []
for _ in range(N_ITER_FUNC):
out = self_tensor.clone()
out.masked_scatter_(mask_tensor, src_tensor)
outs.append(out)
torch_sync()
return outs
def measure(fn):
for _ in range(N_WARMUP):
fn()
start = time.perf_counter_ns()
for _ in range(N_ITER_BENCH):
fn()
end = time.perf_counter_ns()
return (end - start) * 1e-9
def bytes_touched(length, true_count, item_size):
mask_bytes = length
self_bytes = length * item_size * 2 # read + write
src_bytes = true_count * item_size
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
def build_case(length, density, np_dtype, torch_dtype):
true_count = max(1, int(round(length * density)))
rng = np.random.default_rng()
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
mask_np = np.zeros(length, dtype=bool)
mask_np[:true_count] = True
rng.shuffle(mask_np)
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
self_mlx = mx.array(self_np)
mask_mlx = mx.array(mask_np)
src_mlx = mx.array(src_np)
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
# Correctness check once per configuration
mx_out = mx.array(self_np)
mx_out[mask_mlx] = src_mlx
mx.eval(mx_out)
torch_out = self_torch.clone()
torch_out.masked_scatter_(mask_torch, src_torch)
atol = 5e-3 if np_dtype == np.float16 else 1e-5
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
raise AssertionError("masked_scatter results diverged between MLX and Torch")
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
def bench_case(length, density, dtype):
np_dtype = getattr(np, dtype)
torch_dtype = getattr(torch, dtype)
(
self_mlx,
mask_mlx,
src_mlx,
self_torch,
mask_torch,
src_torch,
true_count,
) = build_case(length, density, np_dtype, torch_dtype)
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
time_torch = measure(
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
)
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
bytes_per_gb = float(1024**3)
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
return time_mlx, time_torch, mlx_gbps, torch_gbps
def plot_density(ax_perf, ax_speedup, density, dtype):
mlx_gbps = []
torch_gbps = []
mlx_times = []
torch_times = []
for length in VECTOR_LENGTHS:
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
mlx_gbps.append(gbps_mlx)
torch_gbps.append(gbps_torch)
mlx_times.append(t_mlx)
torch_times.append(t_torch)
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
ax_perf.set_xscale("log", base=2)
ax_perf.set_xticks(VECTOR_LENGTHS)
formatter = FuncFormatter(_power_of_two_formatter)
ax_perf.xaxis.set_major_formatter(formatter)
ax_perf.set_title(f"density={density:.2f}")
ax_perf.set_ylabel("GB/s")
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
ax_perf.legend()
speedup = np.array(torch_times) / np.array(mlx_times)
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
ax_speedup.set_xscale("log", base=2)
ax_speedup.set_xticks(VECTOR_LENGTHS)
ax_speedup.xaxis.set_major_formatter(formatter)
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
def main():
for dtype in D_TYPES:
fig, axs = plt.subplots(
len(MASK_DENSITIES),
2,
figsize=(10, 12),
layout="constrained",
sharex=True,
)
for i, density in enumerate(MASK_DENSITIES):
plot_density(axs[i][0], axs[i][1], density, dtype)
axs[i][0].set_xlabel("vector length")
axs[i][1].set_xlabel("vector length")
fig.suptitle(
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
)
output_path = os.path.join(
RESULTS_DIR,
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
)
fig.savefig(output_path)
plt.close(fig)
if __name__ == "__main__":
main()
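
The benchmark compares MLX boolean-mask assignment against torch.Tensor.masked_scatter_; in both cases the source array supplies one value per True entry of the mask, taken in order. A toy sketch of the MLX side, assuming the semantics exercised above:

import mlx.core as mx

out = mx.zeros((6,))
mask = mx.array([True, False, True, False, False, True])
src = mx.array([1.0, 2.0, 3.0])   # one value per True entry, in order
out[mask] = src
mx.eval(out)
print(out)                        # expected: [1, 0, 2, 0, 0, 3]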

View File

@@ -9,10 +9,7 @@ def rms_norm(x, w, eps):
ot = x.dtype ot = x.dtype
x = x.astype(mx.float32) x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
y = (x * n).astype(ot) return (x * n).astype(ot) * w
if w is not None:
y = y * w
return y
def time_rms_norm(): def time_rms_norm():
@@ -37,27 +34,6 @@ def time_rms_norm():
time_fn(rms_norm_loop, mx.compile(g1), x, w) time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w) time_fn(rms_norm_loop, mx.compile(g2), x, w)
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(rms_norm_loop, g1, x)
time_fn(rms_norm_loop, g2, x)
time_fn(rms_norm_loop, mx.compile(g1), x)
time_fn(rms_norm_loop, mx.compile(g2), x)
if __name__ == "__main__": if __name__ == "__main__":
time_rms_norm() time_rms_norm()
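
A quick parity check between the hand-written rms_norm above and mx.fast.rms_norm, with illustrative shapes and a loose float16 tolerance (an assumption, not something the benchmark asserts):

import mlx.core as mx

x = mx.random.uniform(shape=(2, 16, 64)).astype(mx.float16)
w = mx.random.uniform(shape=(64,)).astype(mx.float16)
xf = x.astype(mx.float32)
ref = (xf * mx.rsqrt(xf.square().mean(-1, keepdims=True) + 1e-5)).astype(x.dtype) * w
print(mx.allclose(ref, mx.fast.rms_norm(x, w, 1e-5), atol=1e-3))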

View File

@@ -28,34 +28,11 @@ def bench(f, *args):
return (e - s) * 1e-9 return (e - s) * 1e-9
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): def mlx_sdpa_fused_inner(q, k, v, scale):
np_dtype = getattr(np, dtype) return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_ref_attn(q, k, v, scale=1.0, mask=None): def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype q_dtype = q.dtype
q = q * mx.array(scale, q_dtype) q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3] n_q_heads = q.shape[-3]
@@ -64,7 +41,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
B = q.shape[0] B = q.shape[0]
L = q.shape[2] L = q.shape[2]
kL = k.shape[2]
if n_repeats > 1: if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
@@ -72,27 +48,10 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
v = mx.expand_dims(v, 2) v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2) scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
if mask is not None: scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else: else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats)) scores = mx.softmax(scores, axis=-1)
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v out = scores @ v
if n_repeats > 1: if n_repeats > 1:
@@ -101,55 +60,74 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
return out return out
def mlx_fused_attn(q, k, v, scale, mask): def mlx_spda_unfused(q, k, v, scale, transpose):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose:
q_t = mx.transpose(q, (0, 2, 1, 3))
k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
q_out = q q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func): for i in range(N_iter_func):
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose) if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out) mx.eval(q_out)
return q_out return q_out
def bench_shape( def mlx_spda_fused(q, k, v, scale, transpose):
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None q_out = q
): if transpose:
q_mx, k_mx, v_mx, scale, mask = prepare_inputs( k = mx.transpose(k, (0, 2, 1, 3))
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
) )
time_mlx_unfused = bench( q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
) v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose) scale = math.sqrt(1.0 / head_dim)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
atol = 1e-5 if dtype == "float32" else 2e-4 q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol): time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print( print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
) )
return time_mlx_fused, time_mlx_unfused return time_mlx_fused, time_mlx_unfused
@@ -173,51 +151,39 @@ if __name__ == "__main__":
( 1, 128, 128, 64, 32, 32), ( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32), ( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32), ( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 8), ( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 8), ( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 8), ( 1, 4096, 4096, 64, 32, 32),
) )
shapes_80 = ( shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh) # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 8), ( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 8), ( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 8), ( 1, 4096, 4096, 80, 32, 32),
) )
shapes_128 = ( shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh) # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 8), ( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 8), ( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 8), ( 1, 4096, 4096, 128, 32, 32),
) )
# fmt: on # fmt: on
shapes = shapes_64 + shapes_80 + shapes_128 shapes = shapes_64 + shapes_80 + shapes_128
masks = [None, "bool", "causal"] print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
for dtype in dtypes: for dtype in dtypes:
for transpose in transposes: for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes: for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
for mask_in in masks: np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape( time_mlx_fused, time_mlx_unfused = bench_shape(
B, B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
qsl,
ksl,
head_dim,
n_q_heads,
n_kv_heads,
dtype,
transpose,
mask_in,
) )
diff = time_mlx_unfused / time_mlx_fused - 1.0 diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0 t_str = 1 if transpose else 0
print( print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%" f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
) )
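
The reworked benchmark sweeps mask types (None, "bool", "causal") and hands the mask straight to mx.fast.scaled_dot_product_attention. A minimal sketch of the causal case, assuming illustrative shapes and the string mask accepted by the fused op here:

import math
import mlx.core as mx

B, H, L, D = 1, 8, 128, 64
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(D), mask="causal"
)
mx.eval(out)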

View File

@@ -1,95 +1,94 @@
import argparse
import math
import mlx.core as mx import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn from time_utils import time_fn
L = 16384 L = 32768
H = 32 H = 32
H_k = H // 4 H_k = H // 4
D = 128 D = 128
V = 128
dtype = mx.float16 dtype = mx.float16
loops = 10 bits = 8
loops = 20
def upproject(x, w): def attention(q, k, v):
if w is None: for _ in range(loops):
return x
else:
return x @ w.T
def attention(q, k, v, mask=None, w=None):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape _, Hk, S, _ = k.shape
_, _, _, V = v.shape
q = q.reshape(B, Hk, Hq // Hk, L, D) q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :] ke = k[:, :, None, :, :]
v = v[:, :, None, :, :] ve = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3) s = q @ ke.transpose(0, 1, 2, 4, 3)
if mask is not None:
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v q = p @ ve
return o.reshape(B, Hq, L, V) q = q.reshape(B, Hq, L, D)
for i in range(loops):
q = _sdpa(q, k, v)
q = upproject(q, w)
return q return q
def sdpa(q, k, v, mask=None, w=None): def sdpa(q, k, v):
for i in range(loops): for _ in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
q = upproject(q, w)
return q return q
def time_self_attention_primitives(): def quant_sdpa(q, k, v, bits=4):
mx.random.seed(3) for _ in range(loops):
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) q = mx.fast.quantized_scaled_dot_product_attention(
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) q, *k, *v, scale=1.0, mask=None, bits=bits
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) )
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None return q
mx.eval(q, k, v, w)
time_fn(attention, q, k, v, w=w)
def time_self_attention_sdpa(): def quant_attention(q, k, v, bits=4):
mx.random.seed(3) for _ in range(loops):
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) B, Hq, L, D = q.shape
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) Hk = k[0].shape[1]
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None q = q.reshape((B, Hk, Hq // Hk, L, D))
mx.eval(q, k, v, w) ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
time_fn(sdpa, q, k, v, w=w) ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
scores = mx.softmax(scores, axis=-1)
q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
q = q.reshape((B, Hq, L, D))
return q
def time_self_attention_sdpa_with_mask(): def time_self_attention_primitives(q, k, v):
mx.random.seed(3) time_fn(attention, q, k, v)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mask = mx.full((L,), True)
mask[L // 2 :] = False
mx.eval(q, k, v, mask, w)
def sdpa_mask(*args):
return sdpa(*args, mask=mask, w=w)
def attention_mask(*args): def time_self_attention_sdpa(q, k, v):
return attention(*args, mask=mask, w=w) time_fn(sdpa, q, k, v)
time_fn(attention_mask, q, k, v)
time_fn(sdpa_mask, q, k, v) def time_self_attention_quant_sdpa(q, k, v, bits=4):
time_fn(quant_sdpa, q, k, v, bits)
def time_self_attention_quant_primitives(q, k, v, bits=4):
time_fn(quant_attention, q, k, v, bits)
if __name__ == "__main__": if __name__ == "__main__":
time_self_attention_sdpa() mx.random.seed(3)
time_self_attention_primitives() q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
time_self_attention_sdpa_with_mask() k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
mx.eval(q, k, v)
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
mx.eval(k_quant, v_quant)
k = mx.dequantize(*k_quant, bits=bits)
v = mx.dequantize(*v_quant, bits=bits)
time_self_attention_sdpa(q, k, v)
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
time_self_attention_primitives(q, k, v)
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
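
A condensed sketch of the quantized KV-cache path the benchmark above exercises: quantize the keys and values once, then unpack the (weights, scales, biases) triples into the quantized SDPA op this branch benchmarks. Shapes and the scale of 1.0 mirror the benchmark and are illustrative:

import mlx.core as mx

H, H_k, L, D, bits = 32, 8, 4096, 128, 8      # illustrative sizes
q = mx.random.uniform(shape=(1, H, 1, D), dtype=mx.float16)
k = mx.random.uniform(shape=(1, H_k, L, D), dtype=mx.float16)
v = mx.random.uniform(shape=(1, H_k, L, D), dtype=mx.float16)
k_q, v_q = mx.quantize(k, bits=bits), mx.quantize(v, bits=bits)
out = mx.fast.quantized_scaled_dot_product_attention(
    q, *k_q, *v_q, scale=1.0, mask=None, bits=bits
)
mx.eval(out)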

View File

@@ -51,20 +51,6 @@ def time_maximum():
time_fn(mx.maximum, a, b) time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative(): def time_negative():
a = mx.random.uniform(shape=(10000, 1000)) a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a) mx.eval(a)
@@ -122,8 +108,6 @@ if __name__ == "__main__":
time_add() time_add()
time_matmul() time_matmul()
time_min()
time_max()
time_maximum() time_maximum()
time_exp() time_exp()
time_negative() time_negative()
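
The new time_max/time_min benchmarks plant a NaN in the input before reducing, so the timed kernels have to deal with NaN propagation. A quick way to observe that behaviour on a small array:

import mlx.core as mx

a = mx.random.uniform(shape=(4, 4))
a[1, 1] = mx.nan
mx.eval(a)
print(mx.max(a, 0))   # inspect how NaN flows through the column-wise max
print(mx.min(a, 0))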

View File

@@ -1,55 +0,0 @@
import time
import mlx.core as mx
rank = mx.distributed.init().rank()
def timeit(fn, a):
# warmup
for _ in range(5):
mx.eval(fn(a))
its = 10
tic = time.perf_counter()
for _ in range(its):
mx.eval(fn(a))
toc = time.perf_counter()
ms = 1000 * (toc - tic) / its
return ms
def all_reduce_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_sum(x)
x = x - 1
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All Reduce: time per iteration {ms:.6f} (ms)")
def all_gather_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_gather(x)[0]
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All gather: time per iteration {ms:.6f} (ms)")
if __name__ == "__main__":
all_reduce_benchmark()
all_gather_benchmark()
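As a sanity check, the collective timed above can also be exercised without a launcher; with a single process the group has size one and ``all_sum`` simply returns its input. A minimal sketch (not part of the original file):

import mlx.core as mx

group = mx.distributed.init()
x = mx.ones((5, 5), mx.int32)
y = mx.distributed.all_sum(x)  # equals x * group.size()
mx.eval(y)
print(group.rank(), y[0, 0].item())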

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -1,3 +0,0 @@
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.

View File

@@ -1,7 +1,5 @@
include(CMakeParseArguments) include(CMakeParseArguments)
# clang format off
#
# ############################################################################## # ##############################################################################
# Build metal library # Build metal library
# #
@@ -11,14 +9,11 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) DEBUG: Boolean, if true, enables debug compile options # files (like headers)
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
# #
# clang format on
macro(mlx_build_metallib) macro(mlx_build_metallib)
# Parse args # Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG) set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -26,11 +21,7 @@ macro(mlx_build_metallib)
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options # Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Prepare metallib build command # Prepare metallib build command
add_custom_command( add_custom_command(

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES FULL_PATH_NAMES = YES
RECURSIVE = YES RECURSIVE = YES
GENERATE_HTML = NO GENERATE_HTML = YES
GENERATE_LATEX = NO GENERATE_LATEX = NO
GENERATE_XML = YES GENERATE_XML = YES
XML_PROGRAMLISTING = YES XML_PROGRAMLISTING = YES

View File

@@ -1,5 +1,4 @@
sphinx sphinx
breathe breathe
sphinx-book-theme sphinx-book-theme
sphinx-copybutton
mlx mlx

Binary file not shown (16 KiB).

Binary file not shown (22 KiB).

View File

@@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "MLX" project = "MLX"
copyright = "2023, Apple" copyright = "2023, MLX Contributors"
author = "MLX Contributors" author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3]) version = ".".join(mx.__version__.split(".")[:3])
release = version release = version
@@ -18,7 +18,6 @@ release = version
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
extensions = [ extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc", "sphinx.ext.autodoc",
"sphinx.ext.autosummary", "sphinx.ext.autosummary",
"sphinx.ext.intersphinx", "sphinx.ext.intersphinx",

View File

@@ -8,12 +8,11 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example Simple Example
-------------- --------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise: Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python .. code-block:: python
def exp_elementwise(a: mx.array):
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
T tmp = inp[elem]; T tmp = inp[elem];
@@ -26,8 +25,6 @@ Let's write a custom kernel that computes ``exp`` elementwise:
output_names=["out"], output_names=["out"],
source=source, source=source,
) )
def exp_elementwise(a: mx.array):
outputs = kernel( outputs = kernel(
inputs=[a], inputs=[a],
template=[("T", mx.float32)], template=[("T", mx.float32)],
@@ -42,13 +39,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
b = exp_elementwise(a) b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a)) assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
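For instance, a hoisted kernel can be reused across calls. A small sketch, assuming ``source`` and an input array ``a`` as in the example above:

.. code-block:: python

   kernel = mx.fast.metal_kernel(
       name="myexp",
       input_names=["inp"],
       output_names=["out"],
       source=source,
   )

   def exp_elementwise(a: mx.array):
       outputs = kernel(
           inputs=[a],
           template=[("T", mx.float32)],
           grid=(a.size, 1, 1),
           threadgroup=(256, 1, 1),
           output_shapes=[a.shape],
           output_dtypes=[a.dtype],
       )
       return outputs[0]

   # The Metal library is built once; subsequent calls reuse it.
   for _ in range(3):
       b = exp_elementwise(a)
   assert mx.allclose(b, mx.exp(a))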
.. note:: .. note::
Only pass the body of the Metal kernel in ``source``. The function We are only required to pass the body of the Metal kernel in ``source``.
signature is generated automatically.
The full function signature will be generated using: The full function signature will be generated using:
@@ -86,34 +78,29 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>; template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.

Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.
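For example, a worked illustration of the dispatch arithmetic (not from the original text):

.. code-block:: python

   grid = (4096, 1, 1)        # mx.prod(grid) == 4096 threads in total
   threadgroup = (256, 1, 1)  # dispatched as 4096 / 256 = 16 threadgroups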
Using Shape/Strides
-------------------

:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed before the
kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry
about gaps or the ordering of the dims when indexing.

If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.

Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python .. code-block:: python
def exp_elementwise(a: mx.array):
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
@@ -127,11 +114,8 @@ relying on a copy from ``ensure_row_contiguous``:
name="myexp_strided", name="myexp_strided",
input_names=["inp"], input_names=["inp"],
output_names=["out"], output_names=["out"],
source=source, source=source
ensure_row_contiguous=False,
) )
def exp_elementwise(a: mx.array):
outputs = kernel( outputs = kernel(
inputs=[a], inputs=[a],
template=[("T", mx.float32)], template=[("T", mx.float32)],
@@ -139,6 +123,7 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes=[a.shape], output_shapes=[a.shape],
output_dtypes=[a.dtype], output_dtypes=[a.dtype],
ensure_row_contiguous=False,
) )
return outputs[0] return outputs[0]
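A quick way to exercise the strided path is to call the kernel on a non-contiguous view, for example a transpose. A small sketch, assuming the ``exp_elementwise`` defined above:

.. code-block:: python

   a = mx.random.normal(shape=(4, 16)).T  # non-contiguous view
   b = exp_elementwise(a)
   assert mx.allclose(b, mx.exp(a))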
@@ -198,13 +183,25 @@ We'll start with the following MLX implementation using standard ops:
return output return output
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel` Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes. to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel: First we'll implement the forward pass as a fused kernel:
.. code-block:: python .. code-block:: python
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
int H = x_shape[1]; int H = x_shape[1];
@@ -254,26 +251,12 @@ First we'll implement the forward pass as a fused kernel:
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
""" """
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="grid_sample", name="grid_sample",
input_names=["x", "grid"], input_names=["x", "grid"],
output_names=["out"], output_names=["out"],
source=source, source=source,
) )
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
outputs = kernel( outputs = kernel(
inputs=[x, grid], inputs=[x, grid],
template=[("T", x.dtype)], template=[("T", x.dtype)],
@@ -298,11 +281,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP Grid Sample VJP
--------------- ---------------
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
define its custom vjp transform so MLX can differentiate it. its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features: requires a few extra ``mx.fast.metal_kernel`` features:
* ``init_value=0`` * ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -316,6 +299,14 @@ We can then implement the backwards pass as follows:
.. code-block:: python .. code-block:: python
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
source = """ source = """
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
int H = x_shape[1]; int H = x_shape[1];
@@ -415,15 +406,6 @@ We can then implement the backwards pass as follows:
source=source, source=source,
atomic_outputs=True, atomic_outputs=True,
) )
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
# pad the output channels to simd group size # pad the output channels to simd group size
# so that our `simd_sum`s don't overlap. # so that our `simd_sum`s don't overlap.
simdgroup_size = 32 simdgroup_size = 32

View File

@@ -22,12 +22,12 @@ You can do that in MLX directly:
This function performs that operation while leaving the implementation and This function performs that operation while leaving the implementation and
function transformations to MLX. function transformations to MLX.
However, you may want to customize the underlying implementation, perhaps to However you may need to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions. make it faster or for custom differentiation. In this tutorial we will go
It will cover: through adding custom extensions. It will cover:
* The structure of the MLX library. * The structure of the MLX library.
* Implementing a CPU operation. * Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using metal. * Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation. * Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python. * Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
Operations are the front-end functions that operate on arrays. They are defined Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them. in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in ``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++: C++:
@@ -55,7 +55,7 @@ C++:
* Scale and sum two vectors element-wise * Scale and sum two vectors element-wise
* z = alpha * x + beta * y * z = alpha * x + beta * y
* *
* Use NumPy-style broadcasting between x and y * Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed * Inputs are upcasted to floats if needed
**/ **/
array axpby( array axpby(
@@ -66,7 +66,7 @@ C++:
StreamOrDevice s = {} // Stream on which to schedule the operation StreamOrDevice s = {} // Stream on which to schedule the operation
); );
The simplest way to implement this is with existing operations: The simplest way to this operation is in terms of existing operations:
.. code-block:: C++ .. code-block:: C++
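The same composition of existing operations in Python looks like the following sketch (it mirrors the ``simple_axpby`` helper used later in the benchmarks and is illustrative only, not part of the extension itself):

.. code-block:: python

   import mlx.core as mx

   def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
       # z = alpha * x + beta * y with numpy-style broadcasting
       return alpha * x + beta * y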
@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^ ^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a defines how to create outputs arrays given a input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function :class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
more concrete: more concrete:
.. code-block:: C++ .. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */ /** The vector-Jacobian product. */
std::vector<array> vjp( std::vector<array> vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents, const array& cotan,
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
@@ -138,13 +138,13 @@ more concrete:
* representing the vectorized computation and the axis which * representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension. * corresponds to the output vectorized dimension.
*/ */
std::pair<std::vector<array>, std::vector<int>> vmap( virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) override; const std::vector<int>& axes) override;
/** The name of primitive. */ /** Print the primitive. */
const char* name() const override { void print(std::ostream& os) override {
return "Axpby"; os << "Axpby";
} }
/** Equivalence check **/ /** Equivalence check **/
@@ -153,6 +153,9 @@ more concrete:
private: private:
float alpha_; float alpha_;
float beta_; float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, array& out);
}; };
The :class:`Axpby` class derives from the base :class:`Primitive` class. The The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -185,7 +188,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
auto promoted_dtype = promote_types(x.dtype(), y.dtype()); auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y // Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32) auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype ? promoted_dtype
: promote_types(promoted_dtype, float32); : promote_types(promoted_dtype, float32);
@@ -231,9 +234,11 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
Implementing the CPU Back-end Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing :meth:`Axpby::eval_cpu`. Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
The method will go over each element of the output array, find the Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`. point-wise. This is captured in the templated function :meth:`axpby_impl`.
@@ -241,46 +246,36 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
template <typename T> template <typename T>
void axpby_impl( void axpby_impl(
const mx::array& x, const array& x,
const mx::array& y, const array& y,
mx::array& out, array& out,
float alpha_, float alpha_,
float beta_, float beta_) {
mx::Stream stream) { // We only allocate memory when we are ready to fill the output
out.set_data(mx::allocator::malloc(out.nbytes())); // malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get the CPU command encoder and register input and output arrays // Collect input and output data pointers
auto& encoder = mx::cpu::get_command_encoder(stream); const T* x_ptr = x.data<T>();
encoder.set_input_array(x); const T* y_ptr = y.data<T>();
encoder.set_input_array(y); T* out_ptr = out.data<T>();
encoder.set_output_array(out);
// Launch the CPU kernel
encoder.dispatch([x_ptr = x.data<T>(),
y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Cast alpha and beta to the relevant types // Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_); T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_); T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output // Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) { for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y // Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides); auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides); auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided // We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping // (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
} }
});
} }
Our implementation should work for all incoming floating point arrays. Our implementation should work for all incoming floating point arrays.
@@ -289,32 +284,112 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
.. code-block:: C++ .. code-block:: C++
void Axpby::eval_cpu( /** Fall back implementation for evaluation on CPU */
const std::vector<mx::array>& inputs, void Axpby::eval(
std::vector<mx::array>& outputs) { const std::vector<array>& inputs,
const std::vector<array>& outputs) {
auto& x = inputs[0]; auto& x = inputs[0];
auto& y = inputs[1]; auto& y = inputs[1];
auto& out = outputs[0]; auto& out = outputs[0];
// Dispatch to the correct dtype // Dispatch to the correct dtype
if (out.dtype() == mx::float32) { if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_, stream()); return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::float16) { } else if (out.dtype() == float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream()); return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::bfloat16) { } else if (out.dtype() == bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream()); return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::complex64) { } else if (out.dtype() == complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream()); return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"Axpby is only supported for floating point types."); "[Axpby] Only supports floating point types.");
} }
} }
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
.. code-block:: C++
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here. primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Back-end Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -345,8 +420,8 @@ element in the output.
constant const float& alpha [[buffer(3)]], constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]], constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]], constant const int* shape [[buffer(5)]],
constant const int64_t* x_strides [[buffer(6)]], constant const size_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]], constant const size_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]], constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array // Convert linear indices to offsets in array
@@ -363,10 +438,24 @@ each instantiation a unique host name so we can identify it.
.. code-block:: C++ .. code-block:: C++
instantiate_kernel("axpby_general_float32", axpby_general, float) #define instantiate_axpby(type_name, type) \
instantiate_kernel("axpby_general_float16", axpby_general, float16_t) template [[host_name("axpby_general_" #type_name)]] \
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t) [[kernel]] void axpby_general<type>( \
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t) device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
The logic to determine the kernel, set the inputs, resolve the grid dimensions, The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
@@ -391,17 +480,17 @@ below.
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
// Allocate output memory // Allocate output memory
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel // Resolve name of kernel
std::stream kname; std::ostringstream kname;
kname = "axpby_general_" + type_to_name(out); kname << "axpby_" << "general_" << type_to_name(out);
// Load the metal library // Make sure the metal library is available
auto lib = d.get_library("mlx_ext", current_binary_dir()); d.register_library("mlx_ext");
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib); auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@@ -469,7 +558,7 @@ one we just defined:
const std::vector<array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents // Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can be built with ops // The jvp transform on the primitive can built with ops
// that are scheduled on the same stream as the primitive // that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the // If argnums = {0}, we only push along x in which case the
@@ -481,7 +570,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype()); auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())}; return {multiply(scale_arr, tangents[0], stream())};
} }
// If argnums = {0, 1}, we take contributions from both // If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta // which gives us jvp = tangent_x * alpha + tangent_y * beta
else { else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -735,7 +824,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}") print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}") print(f"c dtype: {c.dtype}")
print(f"c is correct: {mx.all(c == 6.0).item()}") print(f"c correct: {mx.all(c == 6.0).item()}")
Output: Output:
@@ -743,13 +832,13 @@ Output:
c shape: [3, 4] c shape: [3, 4]
c dtype: float32 c dtype: float32
c is correct: True c correctness: True
Results Results
^^^^^^^ ^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined. with the naive :meth:`simple_axpby` we first defined on the CPU.
.. code-block:: python .. code-block:: python
@@ -757,11 +846,13 @@ with the naive :meth:`simple_axpby` we first defined.
from mlx_sample_extensions import axpby from mlx_sample_extensions import axpby
import time import time
mx.set_default_device(mx.cpu)
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array: def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y return alpha * x + beta * y
M = 4096 M = 256
N = 4096 N = 512
x = mx.random.normal((M, N)) x = mx.random.normal((M, N))
y = mx.random.normal((M, N)) y = mx.random.normal((M, N))
@@ -772,24 +863,24 @@ with the naive :meth:`simple_axpby` we first defined.
def bench(f): def bench(f):
# Warm up # Warm up
for i in range(5): for i in range(100):
z = f(x, y, alpha, beta) z = f(x, y, alpha, beta)
mx.eval(z) mx.eval(z)
# Timed run # Timed run
s = time.time() s = time.time()
for i in range(100): for i in range(5000):
z = f(x, y, alpha, beta) z = f(x, y, alpha, beta)
mx.eval(z) mx.eval(z)
e = time.time() e = time.time()
return 1000 * (e - s) / 100 return e - s
simple_time = bench(simple_axpby) simple_time = bench(simple_axpby)
custom_time = bench(axpby) custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms") print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away! modest improvements right away!
This operation is now good to be used to build other operations, in This operation is now good to be used to build other operations, in

View File

@@ -1,121 +0,0 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
This guide is based on the following `example using MLX in C++
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next make an example program in ``example.cpp``:
.. code-block:: C++
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
The next step is to setup a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
:widths: 20 20
:header-rows: 1
* - Variable
- Description
* - MLX_FOUND
- ``True`` if MLX is found
* - MLX_INCLUDE_DIRS
- Include directory
* - MLX_LIBRARIES
- Libraries to link against
* - MLX_CXX_FLAGS
- Additional compiler flags
* - MLX_BUILD_ACCELERATE
- ``True`` if MLX was built with Accelerate
* - MLX_BUILD_METAL
- ``True`` if MLX was built with Metal

View File

@@ -45,7 +45,6 @@ are the CPU and GPU.
usage/numpy usage/numpy
usage/distributed usage/distributed
usage/using_streams usage/using_streams
usage/export
.. toctree:: .. toctree::
:caption: Examples :caption: Examples
@@ -62,7 +61,6 @@ are the CPU and GPU.
python/array python/array
python/data_types python/data_types
python/devices_and_streams python/devices_and_streams
python/export
python/ops python/ops
python/random python/random
python/transforms python/transforms
@@ -70,8 +68,6 @@ are the CPU and GPU.
python/fft python/fft
python/linalg python/linalg
python/metal python/metal
python/cuda
python/memory_management
python/nn python/nn
python/optimizers python/optimizers
python/distributed python/distributed
@@ -90,4 +86,3 @@ are the CPU and GPU.
dev/extensions dev/extensions
dev/metal_debugger dev/metal_debugger
dev/custom_metal_kernels dev/custom_metal_kernels
dev/mlx_in_cpp

View File

@@ -1,5 +1,3 @@
.. _build_and_install:
Build and Install Build and Install
================= =================
@@ -13,51 +11,22 @@ silicon computer is
pip install mlx pip install mlx
To install from PyPI your system must meet the following requirements: To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon) - Using an M series chip (Apple silicon)
- Using a native Python >= 3.10 - Using a native Python >= 3.9
- macOS >= 14.0 - macOS >= 13.5
.. note:: .. note::
MLX is only available on devices running macOS >= 14.0 and higher. MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
CUDA
^^^^
MLX has a CUDA backend which you can install with: MLX is also available on conda-forge. To install MLX with conda do:
.. code-block:: shell .. code-block:: shell
pip install mlx[cuda12] conda install conda-forge::mlx
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.10
Troubleshooting Troubleshooting
@@ -84,7 +53,7 @@ Build Requirements
^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^
- A C++ compiler with C++17 support (e.g. Clang >= 5.0) - A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make`` - `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0 - Xcode >= 15.0 and macOS SDK >= 14.0
.. note:: .. note::
@@ -94,8 +63,6 @@ Build Requirements
Python API Python API
^^^^^^^^^^ ^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_: `its GitHub repo <https://github.com/ml-explore/mlx>`_:
@@ -107,20 +74,20 @@ Then simply build and install MLX using pip:
.. code-block:: shell .. code-block:: shell
pip install . CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
For developing, install the package with development dependencies, and use an For developing, install the package with development dependencies, and use an
editable install: editable install:
.. code-block:: shell .. code-block:: shell
pip install -e ".[dev]" CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with: Once the development dependencies are installed, you can build faster with:
.. code-block:: shell .. code-block:: shell
python setup.py build_ext --inplace CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
Run the tests with: Run the tests with:
@@ -138,8 +105,6 @@ IDE:
C++ API C++ API
^^^^^^^ ^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source. Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start Similarly to the python library, to build and install the MLX C++ library start
@@ -218,7 +183,6 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version xcrun -sdk macosx --show-sdk-version
Binary Size Minimization Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -247,50 +211,6 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots. Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@@ -19,8 +19,6 @@ Array
array.ndim array.ndim
array.shape array.shape
array.size array.size
array.real
array.imag
array.abs array.abs
array.all array.all
array.any array.any
@@ -40,7 +38,6 @@ Array
array.log10 array.log10
array.log1p array.log1p
array.log2 array.log2
array.logcumsumexp
array.logsumexp array.logsumexp
array.max array.max
array.mean array.mean

View File

@@ -1,9 +0,0 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -51,20 +51,11 @@ The default floating point type is ``float32`` and the default integer type is
* - ``float32`` * - ``float32``
- 4 - 4
- 32-bit float - 32-bit float
* - ``float64``
- 4
- 64-bit double
* - ``complex64`` * - ``complex64``
- 8 - 8
- 64-bit complex float - 64-bit complex float
.. note::
Arrays with type ``float64`` only work with CPU operations. Using
``float64`` arrays on the GPU will result in an exception.
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
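For example, a brief illustration of the hierarchy check (the values shown assume the default types described above):

.. code-block:: python

   import mlx.core as mx

   x = mx.array([1.0, 2.0])                       # defaults to float32
   print(mx.issubdtype(x.dtype, mx.floating))     # True
   print(mx.issubdtype(mx.int32, mx.floating))    # False
   print(mx.issubdtype(mx.complex64, mx.number))  # True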
@@ -75,4 +66,3 @@ documentation for more information. Use :func:`issubdtype` to determine if one
Dtype Dtype
DtypeCategory DtypeCategory
issubdtype issubdtype
finfo

View File

@@ -1,14 +0,0 @@
.. _export:
Export Functions
================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
export_function
import_function
exporter
export_to_dot

View File

@@ -13,4 +13,3 @@ Fast
rope rope
scaled_dot_product_attention scaled_dot_product_attention
metal_kernel metal_kernel
cuda_kernel

View File

@@ -20,5 +20,3 @@ FFT
irfft2 irfft2
rfftn rfftn
irfftn irfftn
fftshift
ifftshift

View File

@@ -16,12 +16,5 @@ Linear Algebra
cross cross
qr qr
svd svd
eigvals
eig
eigvalsh eigvalsh
eigh eigh
lu
lu_factor
pinv
solve
solve_triangular

View File

@@ -1,16 +0,0 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

View File

@@ -8,5 +8,13 @@ Metal
is_available is_available
device_info device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture start_capture
stop_capture stop_capture

View File

@@ -174,7 +174,6 @@ In detail:
value_and_grad value_and_grad
quantize quantize
average_gradients
.. toctree:: .. toctree::

View File

@@ -27,7 +27,6 @@ simple functions.
mish mish
prelu prelu
relu relu
relu2
relu6 relu6
selu selu
sigmoid sigmoid

View File

@@ -50,7 +50,6 @@ Layers
QuantizedLinear QuantizedLinear
RMSNorm RMSNorm
ReLU ReLU
ReLU2
ReLU6 ReLU6
RNN RNN
RoPE RoPE

View File

@@ -32,16 +32,13 @@ Operations
atleast_2d atleast_2d
atleast_3d atleast_3d
bitwise_and bitwise_and
bitwise_invert
bitwise_or bitwise_or
bitwise_xor bitwise_xor
block_masked_mm block_masked_mm
broadcast_arrays
broadcast_to broadcast_to
ceil ceil
clip clip
concatenate concatenate
contiguous
conj conj
conjugate conjugate
convolve convolve
@@ -92,7 +89,6 @@ Operations
isneginf isneginf
isposinf isposinf
issubdtype issubdtype
kron
left_shift left_shift
less less
less_equal less_equal
@@ -103,7 +99,6 @@ Operations
log10 log10
log1p log1p
logaddexp logaddexp
logcumsumexp
logical_not logical_not
logical_and logical_and
logical_or logical_or
@@ -112,7 +107,6 @@ Operations
max max
maximum maximum
mean mean
median
meshgrid meshgrid
min min
minimum minimum
@@ -150,8 +144,6 @@ Operations
sign sign
sin sin
sinh sinh
slice
slice_update
softmax softmax
sort sort
split split
@@ -176,7 +168,6 @@ Operations
tri tri
tril tril
triu triu
unflatten
var var
view view
where where

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads) optimizer.update(model, grads)
# Save the state # Save the state
state = tree_flatten(optimizer.state, destination={}) state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", state) mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint, # Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state # recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2) optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors")) state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state optimizer.state = state
Note, not every optimizer configuration parameter is saved in the state. For

View File

@@ -18,5 +18,3 @@ Common Optimizers
AdamW AdamW
Adamax Adamax
Lion Lion
MultiOptimizer
Muon

View File

@@ -9,7 +9,6 @@ Transforms
:toctree: _autosummary :toctree: _autosummary
eval eval
async_eval
compile compile
custom_function custom_function
disable_compile disable_compile

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python .. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096)) x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x) timeit(nn.gelu, x)
timeit(mx.compile(gelu), x) timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
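The ``timeit`` helper is not shown in this excerpt; a minimal version consistent with the per-call millisecond figures quoted above might look like the following sketch (assuming ``mx`` is already imported):

.. code-block:: python

   import time

   def timeit(fn, x):
       # warm up
       for _ in range(10):
           mx.eval(fn(x))
       its = 100
       tic = time.perf_counter()
       for _ in range(its):
           mx.eval(fn(x))
       toc = time.perf_counter()
       print(f"{1e3 * (toc - tic) / its:.3f} ms per call")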
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
def fun(x, y): def fun(x, y):
z = x + y z = x + y
state.append(z) state.append(z)
return mx.exp(z) return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0)) fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)] # Prints [array(3, dtype=float32)]
@@ -421,77 +421,3 @@ the most opportunity to optimize the computation graph:
# Compiling the outer function is good to do as it will likely # Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled # be faster even though the inner functions are compiled
fun = mx.compile(outer) fun = mx.compile(outer)
.. _shapeless_compile:
Shapeless Compilation
---------------------
When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.
.. code-block:: python
def fun(x, y):
return mx.abs(x + y)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.array(1.0)
y = mx.array(-2.0)
# First call compiles the function
print(compiled_fun(x, y))
# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))
Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:
.. code-block:: python
def fun(x):
return x.reshape(x.shape[0] * x.shape[1], -1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)
The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
.. code-block:: python
def fun(x):
return x.flatten(0, 1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Ok
out = compiled_fun(x)

View File

@@ -5,35 +5,21 @@ Distributed Communication
.. currentmodule:: mlx.core.distributed .. currentmodule:: mlx.core.distributed
MLX supports distributed communication operations that allow the computational cost MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
of training or inference to be shared across many physical machines. At the provide distributed communication operations that allow the computational cost
moment we support several different communication backends introduced below. of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
.. list-table:: .. note::
:widths: 20 80 A lot of operations may not be supported or not as fast as they should be.
:header-rows: 1 We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
* - Backend
- Description
* - :ref:`MPI <mpi_section>`
- A full featured and mature distributed communications library.
* - :ref:`RING <ring_section>`
- Ring all reduce and all gather over TCP sockets. Always available and
usually faster than MPI.
* - :ref:`JACCL <ring_section>`
- Low latency communication with RDMA over thunderbolt. Necessary for
things like tensor parallelism.
* - :ref:`NCCL <nccl_section>`
- The backend of choice for CUDA environments.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
Getting Started Getting Started
--------------- ---------------
A distributed program in MLX is as simple as: MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
.. code:: python .. code:: python
@@ -44,80 +30,74 @@ A distributed program in MLX is as simple as:
print(world.rank(), x) print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all The program above sums the array ``mx.ones(10)`` across all
distributed processes. However, when this script is run with ``python`` only distributed processes. If simply run with ``python``, however, only one
one process is launched and no distributed communication takes place. Namely, process is launched and no distributed communication takes place.
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:
.. code:: python To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
import mlx.core as mx following:
x = ...
world = mx.distributed.init()
# No need for the check we can simply do x = mx.distributed.all_sum(x)
if world.size() > 1:
x = mx.distributed.all_sum(x)
Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
MLX provides ``mlx.launch`` a helper script to launch distributed programs.
Continuing with our initial example we can run it on localhost with 4 processes using
.. code:: shell .. code:: shell
$ mlx.launch -n 4 my_script.py $ mpirun -np 2 python test.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) 1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32) 0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)

.. code:: shell

   $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.

Selecting Backend
^^^^^^^^^^^^^^^^^

You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX
will try all available backends. If they all fail then a singleton group is
created.

.. note::

   After a distributed backend is successfully initialized :func:`init` will
   return **the same backend** if called without arguments or with backend set
   to ``any``.

The following examples aim to clarify the backend initialization logic in MLX:

.. code:: python

   # Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
   world = mx.distributed.init(backend="mpi")
   world2 = mx.distributed.init()  # subsequent calls return the MPI backend!

   # Case 2: Initialize any backend
   world = mx.distributed.init(backend="any")  # equivalent to no arguments
   world2 = mx.distributed.init()  # same as above

   # Case 3: Initialize both backends at the same time
   world_mpi = mx.distributed.init(backend="mpi")
   world_ring = mx.distributed.init(backend="ring")
   world_any = mx.distributed.init()  # same as MPI because it was initialized first!
.. _training_example:

Training Example
----------------

In a data parallel training loop the only addition is averaging the gradients
across processes before the optimizer update, with everything else remaining
the same:

.. code:: python

   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads
       )

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = all_reduce_grads(grads)
       optimizer.update(model, grads)
       return loss

Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.

This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = mlx.nn.average_gradients(grads)  # <---- This line was added
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())
.. _ring_section:
Getting Started with Ring
-------------------------
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests, the nodes are connected in a ring, which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3,
and so on. As a result, :func:`send` and :func:`recv` with an arbitrary sender
and receiver are not supported in the ring backend.
Defining a Ring
^^^^^^^^^^^^^^^
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen on for connections.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node and run the script, which will listen for connections on each of the
provided IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and
accept a connection from ``123.123.123.4``, and so on.
Thunderbolt Ring
^^^^^^^^^^^^^^^^
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
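
For illustration, the following is a rough sketch of the kind of commands
involved, assuming the cable corresponds to ``en2`` on both nodes and the
Thunderbolt bridge shows up as ``bridge0``; interface names and addresses vary
per machine and the commands printed by ``mlx.distributed_config`` are the
authoritative reference.

.. code:: shell

   # On node i (interface name and IPs are illustrative)
   sudo ifconfig bridge0 down
   sudo ifconfig en2 inet 192.168.0.1 netmask 255.255.255.0

   # On node i + 1
   sudo ifconfig bridge0 down
   sudo ifconfig en2 inet 192.168.0.2 netmask 255.255.255.0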
.. _jaccl_section:
Getting Started with RDMA over Thunderbolt
------------------------------------------
Starting from version 26.2, RDMA over Thunderbolt is available in macOS and
enables low-latency communication between Macs with Thunderbolt 5. MLX provides
the JACCL backend, which uses this functionality to achieve communication
latency an order of magnitude lower than the ring backend.
.. note::

   The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
   Communication Library*. It is an obvious pun on Nvidia's NCCL but also a
   tribute to *Jack Beasley*, who led the development of RDMA over Thunderbolt
   at Apple.
Enabling RDMA
^^^^^^^^^^^^^
Until the feature matures, enabling RDMA over thunderbolt is slightly more
involved and **cannot** be done remotely even with sudo. In fact, it has to be
done in macOS recovery:
1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
2. Open the Terminal by going to Utilities -> Terminal.
3. Run ``rdma_ctl enable``.
4. Reboot.
To verify that you have successfully enabled Thunderbolt RDMA you can run
``ibv_devices`` which should produce something like the following for an M3 Ultra.
.. code-block:: bash
~ % ibv_devices
device node GUID
------ ----------------
rdma_en2 8096a9d9edbaac05
rdma_en3 8196a9d9edbaac05
rdma_en5 8396a9d9edbaac05
rdma_en4 8296a9d9edbaac05
rdma_en6 8496a9d9edbaac05
rdma_en7 8596a9d9edbaac05
Defining a Mesh
^^^^^^^^^^^^^^^
The JACCL backend supports only fully connected topologies. Namely, there needs
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
the following topology visualizations, the left one is valid because there is a
connection from any node to any other node, while for the one on the right M3
Ultra 1 is not connected to M3 Ultra 2.
.. raw:: html
<div style="display: flex; text-align: center; align-items: end; font-size: 80%;">
<div>
<img src="/_static/distributed/m3-ultra-mesh.png" alt="M3 Ultra thunderbolt mesh" style="width: 55%">
<p>Fully connected mesh of four M3 Ultra.</p>
</div>
<div>
<img src="/_static/distributed/m3-ultra-mesh-broken.png" alt="M3 Ultra broken thunderbolt mesh" style="width: 55%">
<p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>
</div>
</div>
Similar to the ring backend, the easiest way to use JACCL with MLX is to write
a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain
- Hostnames to use for launching scripts via ssh
- An IP for rank 0 that is reachable by all nodes
- A list of rdma devices that connect each node to each other node
The following JSON defines the valid 4-node mesh from the image above.
.. code-block:: json
[
{
"ssh": "m3-ultra-1",
"ips": ["123.123.123.1"],
"rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
},
{
"ssh": "m3-ultra-2",
"ips": [],
"rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"]
},
{
"ssh": "m3-ultra-3",
"ips": [],
"rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"]
},
{
"ssh": "m3-ultra-4",
"ips": [],
"rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null]
}
]
Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
disabling the thunderbolt bridge and setting up isolated local networks for
each thunderbolt connection are still required.
All of the above can be done instead via ``mlx.distributed_config``. This helper
script will
- ssh into each node
- extract the thunderbolt connectivity
- check for a valid mesh
- provide the commands to configure each node (or run them if sudo is available)
- generate the hostfile to be used with ``mlx.launch``
Putting it All Together
^^^^^^^^^^^^^^^^^^^^^^^^
For example launching a distributed MLX script that uses JACCL is fairly simple
if the nodes are reachable via ssh and have password-less sudo.
First, connect all the thunderbolt cables. Then we can verify the connections
by using the ``mlx.distributed_config`` script to visualize them.
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --dot | dot -Tpng | open -f -a Preview
After making sure that everything looks right we can auto-configure the nodes
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --backend jaccl \
--auto-setup --output m3-ultra-jaccl.json
And now we are ready to run a distributed MLX script such as distributed inference
of a gigantic model using MLX-LM.
.. code-block::
   mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
       --env MLX_METAL_FAST_SYNCH=1 \
       -- /path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-V3.2-8bit --shard

.. note::

   Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a
   different, faster way of synchronizing between the GPU and the CPU. It is
   not specific to the JACCL backend and can be used in any case where the CPU
   and GPU need to collaborate on a computation. It is, however, pretty
   critical for low-latency communication since the communication is done by
   the CPU.
.. _nccl_section:
Getting Started with NCCL
-------------------------
MLX on CUDA environments ships with the ability to talk to `NCCL
<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
communication library that supports both multi-gpu and multi-node setups.
For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
it takes to run a distributed job is
.. code-block::
mlx.launch -n 8 test.py
# perfect for interactive scripts
mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
with the same ease
.. code-block::
mlx.launch --hosts my-cuda-node -n 8 test.py
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
because the cluster scheduler will be the one launching the processes. You can
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
order for the MLX NCCL backend to be initialized correctly.
.. _mpi_section:
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to `MPI
<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed
on the machine. Launching distributed MLX programs that use MPI can be done
with ``mpirun`` as expected. However, in the following examples we will be
using ``mlx.launch --backend mpi`` which takes care of some nuisances such as
setting absolute paths for the ``mpirun`` executable and the ``libmpi.dyld``
shared library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``. Some environments use a non-standard
library filename that can be specified using the ``MPI_LIBNAME`` environment
variable. This is automatically taken care of by ``mlx.launch`` as well.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
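
Once the checklist passes, launching on remote hosts is the same as in the
local case. For instance (the hostnames below are illustrative):

.. code:: shell

   $ mlx.launch --backend mpi --hosts hostname1,hostname2 my_script.py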
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N TCP connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
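
As a sketch, a run that applies both settings directly with ``mpirun`` might
look like the following, where ``en0`` is only an example interface:

.. code:: shell

   $ mpirun -np 2 --mca btl_tcp_links 4 --mca btl_tcp_if_include en0 python my_script.py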
.. _no_mlx_launch:
Distributed Without ``mlx.launch``
----------------------------------
None of the implementations of the distributed backends require launching with
``mlx.launch``. The script simply connects to each host, starts a process per
rank, and sets up the necessary environment variables before delegating to your
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
for more details.
For many use-cases this will be the easiest way to perform distributed
computations in MLX. However, there may be reasons that you cannot or should
not use ``mlx.launch``. A common such case is the use of a scheduler that
starts all the processes for you on machines undetermined at the time of
scheduling the job.
Below we list the environment variables required to use each backend.
Ring
^^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
ports for each rank to listen to, something like the following:
.. code-block:: json
[
["123.123.1.1:5000", "123.123.1.2:5000"],
["123.123.2.1:5000", "123.123.2.2:5000"],
["123.123.3.1:5000", "123.123.3.2:5000"],
["123.123.4.1:5000", "123.123.4.2:5000"]
]
**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
from the distributed backend.
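
Putting these together, launching rank 0 by hand could look roughly like the
following sketch (the hostfile path is illustrative):

.. code:: shell

   $ MLX_RANK=0 MLX_HOSTFILE=/path/to/hostfile.json python my_script.py
   $ # optionally, with extra logging
   $ MLX_RANK=0 MLX_HOSTFILE=/path/to/hostfile.json MLX_RING_VERBOSE=1 python my_script.py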
JACCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 will listen
on and that all the other ranks connect to in order to establish the RDMA
connections.
**MLX_IBV_DEVICES** should contain the path to a json file that contains the
ibverbs device names that connect each node to each other node, something like
the following:
.. code-block:: json
[
[null, "rdma_en5", "rdma_en4", "rdma_en3"],
["rdma_en5", null, "rdma_en3", "rdma_en4"],
["rdma_en4", "rdma_en3", null, "rdma_en5"],
["rdma_en3", "rdma_en4", "rdma_en5", null]
]
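
As a sketch, and assuming the coordinator is given as an ``ip:port`` pair,
rank 0 could be started roughly as follows (the address and file path are
illustrative):

.. code:: shell

   $ MLX_RANK=0 \
     MLX_JACCL_COORDINATOR=123.123.123.1:8000 \
     MLX_IBV_DEVICES=/path/to/devices.json \
     python my_script.py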
NCCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_WORLD_SIZE** should contain the total number of processes that will be
launched.
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
hosts can connect to in order to establish the NCCL communication.
**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that
corresponds to this process.
Of course any `other environment variable
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
used by NCCL can be set.
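
For instance, a scheduler might set up the second of two processes on a node
roughly as follows (the address, port, and GPU index are illustrative):

.. code:: shell

   $ export MLX_RANK=1
   $ export MLX_WORLD_SIZE=2
   $ export NCCL_HOST_IP=123.123.123.1
   $ export NCCL_PORT=12345
   $ export CUDA_VISIBLE_DEVICES=1
   $ python my_script.py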
.. _tips_and_tricks:
Tips and Tricks
----------------
This is a small collection of tips to help you utilize better the distributed
communication capabilities of MLX.
- *Test locally first.*
You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small
scale test on a single node first.
- *Batch your communication.*
As described in the :ref:`training example <training_example>`, performing a
lot of small communication can hurt performance. Copy the approach of
:func:`mlx.nn.average_gradients` to gather many small communications in a
single large one.
- *Visualize the connectivity.*
Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
visualize the connections and make sure that the cables are connected
correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
- *Use the debugger.*
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
processes and gathers stdout from all processes. This makes using ``pdb`` a
breeze.
.. _export_usage:
Exporting Functions
===================
.. currentmodule:: mlx.core
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return x + y
x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("add.mlxfn", fun, x, y)
To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:
.. code-block:: python
add_fun = mx.import_function("add.mlxfn")
out, = add_fun(mx.array(1.0), mx.array(2.0))
# Prints: array(3, dtype=float32)
print(out)
out, = add_fun(mx.array(1.0), mx.array(3.0))
# Prints: array(4, dtype=float32)
print(out)
# Raises an exception
add_fun(mx.array(1), mx.array(3.0))
# Raises an exception
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs are different than the shapes and types of the
example inputs we exported the function with.
Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.
The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:
.. code-block:: python
def fun(x, y):
return x + y
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
# Same as above
mx.export_function("add.mlxfn", fun, (x, y))
imported_fun = mx.import_function("add.mlxfn")
# Ok
out, = imported_fun(x, y)
# Also ok
out, = imported_fun((x, y))
You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.
.. code-block:: python
def fun(x, y):
return x + y
# One argument to fun is positional, the other is a kwarg
mx.export_function("add.mlxfn", fun, x, y=y)
imported_fun = mx.import_function("add.mlxfn")
# Ok
out, = imported_fun(x, y=y)
# Also ok
out, = imported_fun((x,), {"y": y})
# Raises since the keyword argument is missing
out, = imported_fun(x, y)
# Raises since the keyword argument has the wrong key
out, = imported_fun(x, z=y)
Exporting Modules
-----------------
An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:
.. code-block:: python
model = nn.Linear(4, 4)
mx.eval(model.parameters())
def call(x):
return model(x)
mx.export_function("model.mlxfn", call, mx.zeros(4))
In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.
.. note::
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters())``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:
.. code-block:: python
model = nn.Linear(4, 4)
mx.eval(model.parameters())
def call(x, **params):
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
Shapeless Exports
-----------------
Just like :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:
.. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
out, = imported_abs(mx.array([-1.0]))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.
Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.
Exporting Multiple Traces
-------------------------
In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).
The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:
.. code-block:: python
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
exporter(mx.array(1.0))
exporter(mx.array(1.0), y=mx.array(0.0))
imported_function = mx.import_function("fun.mlxfn")
# Call the function with y=None
out, = imported_function(mx.array(1.0))
print(out)
# Call the function with y specified
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
print(out)
In the above example the function's constant data (i.e. ``constant``) is only
saved once.
Transformations with Imported Functions
---------------------------------------
Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
on imported functions just like regular Python functions:
.. code-block:: python
def fun(x):
return mx.sin(x)
x = mx.array(0.0)
mx.export_function("sine.mlxfn", fun, x)
imported_fun = mx.import_function("sine.mlxfn")
# Take the derivative of the imported function
dfdx = mx.grad(lambda x: imported_fun(x)[0])
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
compiled_fun = mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
Importing Functions in C++
--------------------------
Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
setup a simple C++ project that uses MLX as a library.
Next, export a simple function from Python:
.. code-block:: python
def fun(x, y):
return mx.exp(x + y)
x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("fun.mlxfn", fun, x, y)
Import and run the function in C++ with only a few lines of code:
.. code-block:: c++
auto fun = mx::import_function("fun.mlxfn");
auto inputs = {mx::array(1.0), mx::array(1.0)};
auto outputs = fun(inputs);
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.
More Examples
-------------
Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:
* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_
* Indexing does not perform bounds checking. Indexing out of bounds is
  undefined behavior.
* Boolean mask based indexing is supported for assignment only (see
  :ref:`boolean-mask-assignment`).

The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.

In place updates to indexed arrays are possible in MLX and, as in NumPy, they
are reflected in the same array:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> a[2] = 0
   >>> a
   array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
Transformations of functions which use in-place updates are allowed and work as
expected. For example, consider taking the gradient of a function that assigns
``2.0`` to ``x[idx]`` and then sums the array.

In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.
.. _boolean-mask-assignment:
Boolean Mask Assignment
-----------------------
MLX supports boolean indices using NumPy syntax. A mask must already be
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
Other index types are routed through the standard scatter code.
.. code-block:: shell
>>> a = mx.array([1.0, 2.0, 3.0])
>>> mask = mx.array([True, False, True])
>>> updates = mx.array([5.0, 6.0])
>>> a[mask] = updates
>>> a
array([5.0, 2.0, 6.0], dtype=float32)
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
assignments, ``updates`` must provide at least as many elements as there are
``True`` entries in ``mask``.
.. code-block:: shell
>>> a = mx.zeros((2, 3))
>>> mask = mx.array([[True, False, True],
[False, False, True]])
>>> a[mask] = 1.0
>>> a
array([[1.0, 0.0, 1.0],
[0.0, 0.0, 1.0]], dtype=float32)
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. The only
exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full.
.. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
axes and therefore raise errors.
:orphan:
.. _usage_launch_distributed:
Launching Distributed Programs
==============================
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides two utilities to help you configure
your Macs for distributed computation and also launch distributed programs on
multiple nodes or with many processes in a single node. These utilities are aptly named
- ``mlx.launch``
- ``mlx.distributed_config``
See the :doc:`distributed docs <distributed>` for an introduction and
getting-started guides to the various backends.
``mlx.distributed_config``
---------------------------
Unless you are launching distributed jobs locally for development or in
multi-GPU CUDA environments, you have several Macs that you need to configure
for distributed communication with MLX.
``mlx.distributed_config`` aims to automate the process of configuring the
network interfaces (especially for communication over thunderbolt) and also
creating the hostfile to be used with ``mlx.launch``.
We will analyse 3 cases of using ``mlx.distributed_config``
1. RDMA over thunderbolt using JACCL
2. TCP/IP over thunderbolt using the ring backend
3. TCP/IP over ethernet using the ring backend
JACCL
^^^^^^^
After following :ref:`the steps to enable RDMA <jaccl_section>` you can run the
following command to configure the nodes and create the hostfile.
.. code-block::
mlx.distributed_config --verbose --backend jaccl \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \
--auto-setup --output m3-ultra-jaccl.json
Let's walk through the steps that the script takes to configure the nodes.
1. Ssh to all nodes to verify that they are reachable
2. Extract the thunderbolt connectivity. Namely run commands on each node to
calculate which node is connected to which other node.
3. Verify that we have a valid fully connected mesh
4. Check that RDMA is enabled
5. Extract the ethernet IP from interface en0
6. Disable the thunderbolt bridge and set up peer to peer networks for each
thunderbolt cable
7. Write the hostfile
Knowing the above steps allows you to manually configure the nodes and also
debug any configuration issues. For instance, changing the Ethernet IP to a
different interface directly in the config is possible (as long as it is
reachable from all nodes).
The ``--auto-setup`` argument requires password-less sudo on each node. If it
isn't available then the configuration script will print commands to be run on
each node.
Ring over thunderbolt
^^^^^^^^^^^^^^^^^^^^^
Setting up a ring backend over thunderbolt only requires changing the
``--backend`` from ``jaccl`` to ``ring``.
The steps are very similar with the main difference being that instead of
verifying that the nodes are fully connected, the script attempts to identify a
ring topology (or multiple rings).
Ring over Ethernet
^^^^^^^^^^^^^^^^^^
Configuring the ring backend over ethernet doesn't require setting up network
interfaces, so the script simply extracts the ``en0`` IP from each node and
writes the hostfile.
Debugging cable connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^
``mlx.distributed_config`` can help you debug the connectivity of your nodes
over thunderbolt by exporting a graph of the connections.
Running
.. code-block::
mlx.distributed_config --verbose \
--hosts host1,host2,host3,host4 \
--over thunderbolt --dot
will export a `GraphViz <https://graphviz.org>`_ representation of the
connections between the nodes which makes it very easy to figure out which
cable is not connected correctly.
See :ref:`the JACCL section <jaccl_section>` for an example.
``mlx.launch``
--------------
The minimal usage example of ``mlx.launch`` is simply
.. code:: shell
mlx.launch --hosts ip1,ip2 my_script.py
or for testing on localhost
.. code:: shell
mlx.launch -n 2 my_script.py
The ``mlx.launch`` command connects to the provided hosts and launches the input
script on each one. It monitors each of the launched processes and terminates
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.
Importantly, it also broadcasts stdin to each process which enables interactive
programs to work in distributed mode as well as debugging using the interactive
debugger.
Providing Hosts
^^^^^^^^^^^^^^^^
Hosts can be provided as command line arguments, like above, but the most
flexible way to define the list of hosts is via a JSON hostfile. The hostfile
has a very simple schema: it is simply a list of objects that define each host via
a hostname to ssh to and a list of IPs to utilize for the communication.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
{"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
]
You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.
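
For example, a hypothetical invocation that writes such a hostfile could look
like the following (hostnames and output path are illustrative):

.. code:: shell

   mlx.distributed_config --over ethernet --hosts hostname1,hostname2 --output hostfile.json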
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^^
In order to be able to launch the script on each host we need to be able to
connect via ssh. Moreover the input script and python binary need to be on each
host and on the same path. A good checklist to debug errors is the following:
* ``ssh hostname`` works without asking for password or host confirmation
* the python binary is available on all hosts at the same path. You can use
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
If you are launching from a node with a completely different setup than the
nodes that the program will run on, you can specify ``--no-verify-script`` so
that ``mlx.launch`` does not attempt to verify that the executable and script
exist locally before launching the distributed job.
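
A sketch of such a launch, with an illustrative hostfile name, is simply:

.. code:: shell

   mlx.launch --hostfile hosts.json --no-verify-script my_script.py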
.. _ring_specifics:
Ring Specifics
^^^^^^^^^^^^^^
The :ref:`ring <ring_section>` backend, which is also the default
backend, can be explicitly selected with the argument ``--backend ring``. The
ring backend has some specific requirements and arguments that are different to
other backends:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.
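
The following sketch combines these arguments; the port, hostfile, and number
of connections are illustrative:

.. code:: shell

   mlx.launch --backend ring --hostfile hosts.json \
       --starting-port 6000 --connections-per-ip 2 my_script.py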
.. _jaccl_specifics:
JACCL Specifics
^^^^^^^^^^^^^^^^
The :ref:`JACCL <jaccl_section>` backend can be selected with the argument
``--backend jaccl``. A hostfile is necessary to launch with this backend
because it needs to contain the RDMA devices connecting each node to each other
node.
NCCL Specifics
^^^^^^^^^^^^^^
The :ref:`NCCL <nccl_section>` backend is the default backend for CUDA
environments. When launching from a Mac to a Linux machine with CUDA, the
backend should be selected using ``--backend nccl``.
The ``--repeat-hosts, -n`` argument should be used to launch multi-node and
multi-gpu jobs. For instance
.. code-block::
mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh
will attempt to launch 16 processes, 8 on each node that will all run
``my-job.sh``.
.. _mpi_specifics:
MPI Specifics
^^^^^^^^^^^^^
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
* The IPs in the hostfile are ignored
* The ssh connectivity requirement is stronger as every node needs to be able
to connect to every other node
* ``mpirun`` needs to be available on every node at the same path
Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
to choose a specific interface for the byte-transfer-layer of MPI we can call
``mlx.launch`` as follows:
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. note::

   Since NumPy does not support ``bfloat16`` arrays, you will need to convert
   to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
   Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
   buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by creating
an array view:

.. code-block:: python

   a = mx.arange(3)
   a_view = np.array(a, copy=False)
   a_view[0] = 1
   print(a[0].item())  # 1

.. note::

   NumPy arrays with type ``float64`` will by default be converted to MLX
   arrays with type ``float32``.

A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.

While this is quite powerful to prevent copying arrays, it should be noted that
external changes to the memory of arrays cannot be reflected in gradients.

Let's demonstrate this in an example:

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated. It's important to note that a similar issue arises
during array conversion and copying. For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.

PyTorch
-------

PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())

Conversion from PyTorch tensors back to arrays must be done via intermediate
NumPy arrays with ``numpy()``.

JAX
---

JAX fully supports the buffer protocol.

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Comment out the following two commands if the MLX C++ library is installed
# separately and set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
## Build and Run
Install MLX with Python:
```bash
pip install "mlx>=0.22"
```
Build the C++ example:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
Run the C++ example:
```
./build/example
```
which should output:
```
array([2, 4, 6], dtype=int32)
```
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  if (!mx::distributed::is_available()) {
    std::cout << "No communication backend found" << std::endl;
    return 1;
  }

  auto global_group = mx::distributed::init();
  std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

  mx::array x = mx::ones({10});
  mx::array out = mx::distributed::all_sum(x, global_group);

  std::cout << out << std::endl;
}
/**
 * An example of linear regression with MLX.
 */
namespace mx = mlx::core;

int main() {
  int num_features = 100;

  float learning_rate = 0.01;

  // True parameters
  auto w_star = mx::random::normal({num_features});

  // The input examples (design matrix)
  auto X = mx::random::normal({num_examples, num_features});

  // Noisy labels
  auto eps = 1e-2 * mx::random::normal({num_examples});
  auto y = mx::matmul(X, w_star) + eps;

  // Initialize random parameters
  mx::array w = 1e-2 * mx::random::normal({num_features});

  auto loss_fn = [&](mx::array w) {
    auto yhat = mx::matmul(X, w);
    return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
  };

  auto grad_fn = mx::grad(loss_fn);

  auto tic = timer::time();
  for (int it = 0; it < num_iters; ++it) {
    auto grads = grad_fn(w);
    w = w - learning_rate * grads;
    mx::eval(w);
  }
  auto toc = timer::time();

  auto loss = loss_fn(w);
  auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
  auto throughput = num_iters / timer::seconds(toc - tic);

  std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
            << ", Throughput " << throughput << " (it/s)." << std::endl;
/**
 * An example of logistic regression with MLX.
 */
namespace mx = mlx::core;

int main() {
  int num_features = 100;

  float learning_rate = 0.1;

  // True parameters
  auto w_star = mx::random::normal({num_features});

  // The input examples
  auto X = mx::random::normal({num_examples, num_features});

  // Labels
  auto y = mx::matmul(X, w_star) > 0;

  // Initialize random parameters
  mx::array w = 1e-2 * mx::random::normal({num_features});

  auto loss_fn = [&](mx::array w) {
    auto logits = mx::matmul(X, w);
    auto scale = (1.0f / num_examples);
    return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
  };

  auto grad_fn = mx::grad(loss_fn);

  auto tic = timer::time();
  for (int it = 0; it < num_iters; ++it) {
    auto grads = grad_fn(w);
    w = w - learning_rate * grads;
    mx::eval(w);
  }
  auto toc = timer::time();

  auto loss = loss_fn(w);
  auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
  auto throughput = num_iters / timer::seconds(toc - tic);

  std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
            << throughput << " (it/s)." << std::endl;
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // To use Metal debugging and profiling:
  // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
  // 2. Run with MTL_CAPTURE_ENABLED=1.
  mx::metal::start_capture("mlx_trace.gputrace");

  // Start at index two because the default GPU and CPU streams have indices
  // zero and one, respectively. This naming matches the label assigned to each
  // stream's command queue.
  auto s2 = new_stream(mx::Device::gpu);
  auto s3 = new_stream(mx::Device::gpu);

  auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
  auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
  auto x = mx::add(a, a, s2);
  auto y = mx::add(b, b, s3);

  // The multiply will happen on the default stream.
  std::cout << mx::multiply(x, y) << std::endl;

  mx::metal::stop_capture();
}
#include "mlx/mlx.h"

namespace mx = mlx::core;

void array_basics() {
  // Make a scalar array:
  mx::array x(1.0);

  // Get the value out of it:
  auto s = x.item<float>();

  // The datatype should be float32:
  auto dtype = x.dtype();
  assert(dtype == mx::float32);

  // Specify the dtype when constructing the array:
  x = mx::array(1, mx::int32);
  assert(x.dtype() == mx::int32);
  x.item<int>();  // OK
  // x.item<float>(); // Undefined!

  // Make a multidimensional array:
  x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  // mlx is row-major by default so the first row of this array
  // is [1.0, 2.0] and the second row is [3.0, 4.0]

  // Make an array of shape {2, 2} filled with ones:
  auto y = mx::ones({2, 2});

  // Pointwise add x and y:
  auto z = mx::add(x, y);

  // Same thing:
  z = x + y;

  // mlx is lazy by default. At this point `z` only
  // has a shape and a type but no actual data:
  assert(z.dtype() == mx::float32);
  assert(z.shape(0) == 2);
  assert(z.shape(1) == 2);

  // and inputs. When `eval` is called on an array (or arrays), the array and
  // all of its dependencies are recursively evaluated to produce the result.
  // Once an array is evaluated, it has data and is detached from its inputs.
  mx::eval(z);

  // Of course the array can still be an input to other operations. You can
  // even call eval on the array again, this will just be a no-op:
  mx::eval(z);  // no-op

  // Some functions or methods on arrays implicitly evaluate them. For example
  // accessing a value in an array or printing the array implicitly evaluate it:
  z = mx::ones({1});
  z.item<float>();  // implicit evaluation

  z = mx::ones({2, 2});
  std::cout << z << std::endl;  // implicit evaluation
}

void automatic_differentiation() {
  auto fn = [](mx::array x) { return mx::square(x); };

  // Computing the derivative function of a function
  auto grad_fn = mx::grad(fn);

  // Call grad_fn on the input to get the derivative
  auto x = mx::array(1.5);
  auto dfdx = grad_fn(x);
  // dfdx is 2 * x

  // Get the second derivative by composing grad with grad
  auto d2fdx2 = mx::grad(mx::grad(fn))(x);
  // d2fdx2 is 2
}
cmake_minimum_required(VERSION 3.27)
project(import_mlx LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)
add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)
## Setup
Install MLX:
```bash
pip install "mlx>=0.22"
```
Build the C++ examples:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
## Run
### Eval MLP
Run the Python script to export the eval function:
```bash
python eval_mlp.py
```
Then run the C++ program to import and run the function:
```
./build/eval_mlp
```
The Python and C++ programs should output the same result.
### Train MLP
Run the Python script to export the model initialization and training
functions:
```bash
python train_mlp.py
```
Then run the C++ program to import and run the functions:
```
./build/train_mlp
```
The Python and C++ programs should output the same results.
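The workflow above boils down to exporting a traced function from Python and importing it elsewhere. A condensed sketch of that round trip, distilled from the eval_mlp.py script shown below (the trivial forward function here is a hypothetical stand-in for the MLP, used only to keep the sketch short):

```python
# Condensed export/import round trip, distilled from eval_mlp.py below.
import mlx.core as mx

def forward(x):
    return x * 2.0  # stand-in for the MLP forward pass

example_x = mx.random.uniform(shape=(8, 32))
mx.export_function("eval_mlp.mlxfn", forward, example_x)  # trace and save
imported = mx.import_function("eval_mlp.mlxfn")           # load it back
(out,) = imported(example_x)                               # returns a tuple of outputs
assert mx.allclose(forward(example_x), out)
```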


@@ -1,25 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <mlx/mlx.h>
#include <iostream>
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
// Make the input
mx::random::seed(42);
auto example_x = mx::random::uniform({batch_size, input_dim});
// Import the function
auto forward = mx::import_function("eval_mlp.mlxfn");
// Call the imported function
auto out = forward({example_x})[0];
std::cout << out << std::endl;
return 0;
}


@@ -1,52 +0,0 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class MLP(nn.Module):
"""A simple MLP."""
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = nn.relu(l(x))
return self.layers[-1](x)
if __name__ == "__main__":
batch_size = 8
input_dim = 32
output_dim = 10
# Load the model
mx.random.seed(0) # Seed for params
model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
mx.eval(model)
# Note, the model parameters are saved in the export function
def forward(x):
return model(x)
mx.random.seed(42) # Seed for input
example_x = mx.random.uniform(shape=(batch_size, input_dim))
mx.export_function("eval_mlp.mlxfn", forward, example_x)
# Import in Python
imported_forward = mx.import_function("eval_mlp.mlxfn")
expected = forward(example_x)
(out,) = imported_forward(example_x)
assert mx.allclose(expected, out)
print(out)


@@ -1,35 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <mlx/mlx.h>
#include <iostream>
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
int output_dim = 10;
auto state = mx::import_function("init_mlp.mlxfn")({});
// Make the input
mx::random::seed(42);
auto example_X = mx::random::normal({batch_size, input_dim});
auto example_y = mx::random::randint(0, output_dim, {batch_size});
// Import the function
auto step = mx::import_function("train_mlp.mlxfn");
// Call the imported function
for (int it = 0; it < 100; ++it) {
state.insert(state.end(), {example_X, example_y});
state = step(state);
eval(state);
auto loss = state.back();
state.pop_back();
if (it % 10 == 0) {
std::cout << "Loss " << loss.item<float>() << std::endl;
}
}
return 0;
}


@@ -1,76 +0,0 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils
class MLP(nn.Module):
"""A simple MLP."""
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = nn.relu(l(x))
return self.layers[-1](x)
if __name__ == "__main__":
batch_size = 8
input_dim = 32
output_dim = 10
def init():
# Seed for the parameter initialization
mx.random.seed(0)
model = MLP(
num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
)
optimizer = optim.SGD(learning_rate=1e-1)
optimizer.init(model.parameters())
state = [model.parameters(), optimizer.state]
tree_structure, state = zip(*mlx.utils.tree_flatten(state))
return model, optimizer, tree_structure, state
# Export the model parameter initialization
model, optimizer, tree_structure, state = init()
mx.eval(state)
mx.export_function("init_mlp.mlxfn", lambda: init()[-1])
def loss_fn(params, X, y):
model.update(params)
return nn.losses.cross_entropy(model(X), y, reduction="mean")
def step(*inputs):
*state, X, y = inputs
params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
optimizer.state = opt_state
loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
params = optimizer.apply_gradients(grads, params)
_, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
return *state, loss
# Make some random data
mx.random.seed(42)
example_X = mx.random.normal(shape=(batch_size, input_dim))
example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))
mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)
# Export one step of SGD
imported_step = mx.import_function("train_mlp.mlxfn")
for it in range(100):
*state, loss = imported_step(*state, example_X, example_y)
if it % 10 == 0:
print(f"Loss {loss.item():.6}")


@@ -10,6 +10,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
+find_package(MLX CONFIG REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
@@ -17,15 +18,10 @@ find_package(
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
-OUTPUT_VARIABLE nanobind_ROOT)
+OUTPUT_VARIABLE NB_DIR)
+list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
-execute_process(
-COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
-OUTPUT_STRIP_TRAILING_WHITESPACE
-OUTPUT_VARIABLE MLX_ROOT)
-find_package(MLX CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
# Add library


@@ -1,34 +1,25 @@
-// Copyright © 2023-2025 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
-#include <dlfcn.h>
+#include <cassert>
#include <iostream>
#include <sstream>
+#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
-#include "mlx/backend/cpu/encoder.h"
#include "mlx/utils.h"
#include "axpby/axpby.h"
+#ifdef ACCELERATE_NEW_LAPACK
+#include <vecLib/cblas_new.h>
+#endif
#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#endif
-namespace my_ext {
+namespace mlx::core {
-// A helper function to find the location of the current binary on disk.
-// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
-std::string current_binary_dir() {
-static std::string binary_dir = []() {
-Dl_info info;
-if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
-throw std::runtime_error("Unable to get current binary dir.");
-}
-return std::filesystem::path(info.dli_fname).parent_path().string();
-}();
-return binary_dir;
-}
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
@@ -41,24 +32,24 @@ std::string current_binary_dir() {
 * Follow numpy style broadcasting between x and y
 * Inputs are upcasted to floats if needed
 **/
-mx::array axpby(
+array axpby(
-const mx::array& x, // Input mx::array x
-const mx::array& y, // Input mx::array y
+const array& x, // Input array x
+const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
-mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
+StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
-auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
+auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype
-: promote_types(promoted_dtype, mx::float32);
+: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
-auto x_casted = mx::astype(x, out_dtype, s);
-auto y_casted = mx::astype(y, out_dtype, s);
+auto x_casted = astype(x, out_dtype, s);
+auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -66,12 +57,12 @@ mx::array axpby(
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
-return mx::array(
+return array(
-/* const mx::Shape& shape = */ out_shape,
-/* mx::Dtype dtype = */ out_dtype,
-/* std::shared_ptr<mx::Primitive> primitive = */
+/* const std::vector<int>& shape = */ out_shape,
+/* Dtype dtype = */ out_dtype,
+/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
-/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
+/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
///////////////////////////////////////////////////////////////////////////////
@@ -80,69 +71,140 @@ mx::array axpby(
template <typename T>
void axpby_impl(
-const mx::array& x,
-const mx::array& y,
-mx::array& out,
+const array& x,
+const array& y,
+array& out,
float alpha_,
-float beta_,
-mx::Stream stream) {
-out.set_data(mx::allocator::malloc(out.nbytes()));
-// Get the CPU command encoder and register input and output arrays
-auto& encoder = mx::cpu::get_command_encoder(stream);
-encoder.set_input_array(x);
-encoder.set_input_array(y);
-encoder.set_output_array(out);
-// Launch the CPU kernel
-encoder.dispatch([x_ptr = x.data<T>(),
-y_ptr = y.data<T>(),
-out_ptr = out.data<T>(),
-size = out.size(),
-shape = out.shape(),
-x_strides = x.strides(),
-y_strides = y.strides(),
-alpha_,
-beta_]() {
+float beta_) {
+// We only allocate memory when we are ready to fill the output
+// malloc_or_wait synchronously allocates available memory
+// There may be a wait executed here if the allocation is requested
+// under memory-pressured conditions
+out.set_data(allocator::malloc_or_wait(out.nbytes()));
+// Collect input and output data pointers
+const T* x_ptr = x.data<T>();
+const T* y_ptr = y.data<T>();
+T* out_ptr = out.data<T>();
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output
-for (size_t out_idx = 0; out_idx < size; out_idx++) {
+for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
-auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
-auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
+auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
+auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
-});
}
-void Axpby::eval_cpu(
-const std::vector<mx::array>& inputs,
-std::vector<mx::array>& outputs) {
+/** Fall back implementation for evaluation on CPU */
+void Axpby::eval(
+const std::vector<array>& inputs,
+std::vector<array>& outputs) {
+// Check the inputs (registered in the op while constructing the out array)
+assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
-if (out.dtype() == mx::float32) {
-return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
-} else if (out.dtype() == mx::float16) {
-return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
-} else if (out.dtype() == mx::bfloat16) {
-return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
-} else if (out.dtype() == mx::complex64) {
-return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
+if (out.dtype() == float32) {
+return axpby_impl<float>(x, y, out, alpha_, beta_);
+} else if (out.dtype() == float16) {
+return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
+} else if (out.dtype() == bfloat16) {
+return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
+} else if (out.dtype() == complex64) {
+return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Primitive Accelerate Backend Implementation
///////////////////////////////////////////////////////////////////////////////
#ifdef ACCELERATE_NEW_LAPACK
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
}
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
eval(inputs, outputs);
}
#endif
///////////////////////////////////////////////////////////////////////////////
// Primitive Metal Backend Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -151,9 +213,10 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
-const std::vector<mx::array>& inputs,
-std::vector<mx::array>& outputs) {
+const std::vector<array>& inputs,
+std::vector<array>& outputs) {
// Prepare inputs
+assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
@@ -162,7 +225,7 @@ void Axpby::eval_gpu(
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
-auto& d = mx::metal::device(s.device);
+auto& d = metal::device(s.device);
// Prepare to specialize based on contiguity
bool contiguous_kernel =
@@ -172,24 +235,25 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
-mx::allocator::malloc(x.data_size() * out.itemsize()),
+allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
-out.set_data(mx::allocator::malloc(out.nbytes()));
+out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
-std::string kname = "axpby_";
-kname += (contiguous_kernel ? "contiguous_" : "general_");
-kname += type_to_name(out);
+std::ostringstream kname;
+kname << "axpby_";
+kname << (contiguous_kernel ? "contiguous_" : "general_");
+kname << type_to_name(out);
-// Load the metal library
-auto lib = d.get_library("mlx_ext", current_binary_dir());
+// Make sure the metal library is available
+d.register_library("mlx_ext");
// Make a kernel from this metal library
-auto kernel = d.get_kernel(kname, lib);
+auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -215,7 +279,7 @@ void Axpby::eval_gpu(
if (!contiguous_kernel) {
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
-compute_encoder.set_vector_bytes(y.strides(), 7);
+compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
}
@@ -238,8 +302,8 @@ void Axpby::eval_gpu(
/** Fail evaluation on GPU */
void Axpby::eval_gpu(
-const std::vector<mx::array>& inputs,
-std::vector<mx::array>& out) {
+const std::vector<array>& inputs,
+std::vector<array>& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -250,9 +314,9 @@ void Axpby::eval_gpu(
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
-std::vector<mx::array> Axpby::jvp(
-const std::vector<mx::array>& primals,
-const std::vector<mx::array>& tangents,
+std::vector<array> Axpby::jvp(
+const std::vector<array>& primals,
+const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
@@ -264,8 +328,8 @@ std::vector<mx::array> Axpby::jvp(
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
-auto scale_arr = mx::array(scale, tangents[0].dtype());
-return {mx::multiply(scale_arr, tangents[0], stream())};
+auto scale_arr = array(scale, tangents[0].dtype());
+return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -275,24 +339,24 @@ std::vector<mx::array> Axpby::jvp(
}
/** The vector-Jacobian product. */
-std::vector<mx::array> Axpby::vjp(
-const std::vector<mx::array>& primals,
-const std::vector<mx::array>& cotangents,
+std::vector<array> Axpby::vjp(
+const std::vector<array>& primals,
+const std::vector<array>& cotangents,
const std::vector<int>& argnums,
-const std::vector<mx::array>&) {
+const std::vector<array>&) {
// Reverse mode diff
-std::vector<mx::array> vjps;
+std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
-auto scale_arr = mx::array(scale, cotangents[0].dtype());
-vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
+auto scale_arr = array(scale, cotangents[0].dtype());
+vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
/** Vectorize primitive along given axis */
-std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
-const std::vector<mx::array>& inputs,
+std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
+const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
@@ -303,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
-} // namespace my_ext
+} // namespace mlx::core


@@ -1,13 +1,11 @@
-// Copyright © 2023-2025 Apple Inc.
+// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/ops.h"
#include "mlx/primitives.h"
-namespace mx = mlx::core;
-namespace my_ext {
+namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
// Operation
@@ -20,22 +18,22 @@ namespace my_ext {
 * Follow numpy style broadcasting between x and y
 * Inputs are upcasted to floats if needed
 **/
-mx::array axpby(
-const mx::array& x, // Input array x
-const mx::array& y, // Input array y
+array axpby(
+const array& x, // Input array x
+const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
-mx::StreamOrDevice s = {} // Stream on which to schedule the operation
+StreamOrDevice s = {} // Stream on which to schedule the operation
);
///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////
-class Axpby : public mx::Primitive {
+class Axpby : public Primitive {
public:
-explicit Axpby(mx::Stream stream, float alpha, float beta)
-: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
+explicit Axpby(Stream stream, float alpha, float beta)
+: Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
 * A primitive must know how to evaluate itself on the CPU/GPU
@@ -44,25 +42,23 @@ class Axpby : public mx::Primitive {
 * To avoid unnecessary allocations, the evaluation function
 * is responsible for allocating space for the array.
 */
-void eval_cpu(
-const std::vector<mx::array>& inputs,
-std::vector<mx::array>& outputs) override;
-void eval_gpu(
-const std::vector<mx::array>& inputs,
-std::vector<mx::array>& outputs) override;
+void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+override;
+void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+override;
/** The Jacobian-vector product. */
-std::vector<mx::array> jvp(
-const std::vector<mx::array>& primals,
-const std::vector<mx::array>& tangents,
+std::vector<array> jvp(
+const std::vector<array>& primals,
+const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
-std::vector<mx::array> vjp(
-const std::vector<mx::array>& primals,
-const std::vector<mx::array>& cotangents,
+std::vector<array> vjp(
+const std::vector<array>& primals,
+const std::vector<array>& cotangents,
const std::vector<int>& argnums,
-const std::vector<mx::array>& outputs) override;
+const std::vector<array>& outputs) override;
/**
 * The primitive must know how to vectorize itself across
@@ -70,21 +66,24 @@ class Axpby : public mx::Primitive {
 * representing the vectorized computation and the axis which
 * corresponds to the output vectorized dimension.
 */
-std::pair<std::vector<mx::array>, std::vector<int>> vmap(
-const std::vector<mx::array>& inputs,
+std::pair<std::vector<array>, std::vector<int>> vmap(
+const std::vector<array>& inputs,
const std::vector<int>& axes) override;
-/** The name of primitive. */
-const char* name() const override {
-return "Axpby";
+/** Print the primitive. */
+void print(std::ostream& os) override {
+os << "Axpby";
}
/** Equivalence check **/
-bool is_equivalent(const mx::Primitive& other) const override;
+bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
+/** Fall back implementation for evaluation on CPU */
+void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
-} // namespace my_ext
+} // namespace mlx::core


@@ -1,4 +1,4 @@
-// Copyright © 2023-2025 Apple Inc.
+// Copyright © 2023 Apple Inc.
#include <metal_stdlib>
@@ -12,8 +12,8 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
-constant const int64_t* x_strides [[buffer(6)]],
-constant const int64_t* y_strides [[buffer(7)]],
+constant const size_t* x_strides [[buffer(6)]],
+constant const size_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
@@ -34,14 +34,29 @@ template <typename T>
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
+// clang-format off
#define instantiate_axpby(type_name, type) \
-instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
-instantiate_kernel( \
-"axpby_contiguous_" #type_name, axpby_contiguous, type)
+template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
+axpby_general<type>( \
+device const type* x [[buffer(0)]], \
+device const type* y [[buffer(1)]], \
+device type* out [[buffer(2)]], \
+constant const float& alpha [[buffer(3)]], \
+constant const float& beta [[buffer(4)]], \
+constant const int* shape [[buffer(5)]], \
+constant const size_t* x_strides [[buffer(6)]], \
+constant const size_t* y_strides [[buffer(7)]], \
+constant const int& ndim [[buffer(8)]], \
+uint index [[thread_position_in_grid]]); \
+template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
+axpby_contiguous<type>( \
+device const type* x [[buffer(0)]], \
+device const type* y [[buffer(1)]], \
+device type* out [[buffer(2)]], \
+constant const float& alpha [[buffer(3)]], \
+constant const float& beta [[buffer(4)]], \
+uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
+// clang-format on


@@ -8,12 +8,14 @@
namespace nb = nanobind;
using namespace nb::literals;
+using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
-&my_ext::axpby,
+&axpby,
"x"_a,
"y"_a,
"alpha"_a,


@@ -1,8 +1,8 @@
[build-system]
requires = [
"setuptools>=42",
-"cmake>=3.25",
+"cmake>=3.24",
"mlx>=0.18.0",
-"nanobind==2.4.0",
+"nanobind==2.2.0",
]
build-backend = "setuptools.build_meta"


@@ -1,4 +1,4 @@
setuptools>=42
-cmake>=3.25
+cmake>=3.24
mlx>=0.21.0
-nanobind==2.4.0
+nanobind==2.2.0


@@ -3,10 +3,8 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
-c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
-c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
+c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
-print(f"c shape: {c_cpu.shape}")
-print(f"c dtype: {c_cpu.dtype}")
-print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
-print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
+print(f"c shape: {c.shape}")
+print(f"c dtype: {c.dtype}")
+print(f"c correct: {mx.all(c == 6.0).item()}")


@@ -1,11 +1,10 @@
target_sources(
mlx
-PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -19,48 +18,24 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
-# Define MLX_VERSION only in the version.cpp file.
-add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
-target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
-target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
-if(MSVC)
-# Disable some MSVC warnings to speed up compilation.
-target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
-endif()
-if(WIN32)
-# Export symbols by default to behave like macOS/linux.
-set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
-endif()
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
if(MLX_BUILD_CPU)
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
+if(MLX_BUILD_ACCELERATE)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
+elseif(MLX_BUILD_CPU)
+target_sources(
+mlx
+PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
+endif()
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
-target_sources(mlx
-PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
-endif()
-if(MLX_BUILD_CUDA)
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
-else()
-target_sources(mlx
-PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
-endif()
-if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
-else()
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()

Some files were not shown because too many files have changed in this diff.