Compare commits

..

6 Commits

Author                  SHA1         Message                       Date
Angelos Katharopoulos   83762691ba   Fix four step fft             2025-05-08 14:14:59 -07:00
Angelos Katharopoulos   2a41caa00e   Add single kernel bluestein   2025-05-08 13:23:11 -07:00
Angelos Katharopoulos   6593281d25   Refactored four-step          2025-05-08 13:23:11 -07:00
Angelos Katharopoulos   da98e8bce8   Refactored stockham           2025-05-08 13:23:11 -07:00
Angelos Katharopoulos   be57a16a80   More tmp fft changes          2025-05-08 13:23:11 -07:00
Angelos Katharopoulos   1704809f29   Tmp FFT commit                2025-05-08 13:23:11 -07:00
560 changed files with 11343 additions and 52377 deletions

.circleci/config.yml (new file, 600 lines)

@@ -0,0 +1,600 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
- run:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build example extension
command: |
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
build_release:
parameters:
python_version:
type: string
default: "3.9"
xcode_version:
type: string
default: "16.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
command: |
source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
build_linux_release:
parameters:
python_version:
type: string
default: "3.9"
extra_env:
type: string
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
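The `nightly_build`, `weekly_build`, `test_release`, and `linux_release` pipeline parameters declared at the top of this config are only consumed by the `when:` clauses of the workflows above; nothing in the file sets them. Presumably they are supplied by a scheduled or manually triggered pipeline. Below is a hedged sketch of such a trigger, assuming the standard CircleCI v2 "trigger a pipeline" endpoint and an API token exported as `CIRCLECI_TOKEN`; neither the script nor the token handling is part of this diff.

```python
# Hypothetical trigger for the parameterized workflows above (not part of the
# diff). Assumes CircleCI's v2 pipeline-trigger endpoint and a personal API
# token exported as CIRCLECI_TOKEN.
import json
import os
import urllib.request

def trigger_pipeline(parameters: dict, branch: str = "main") -> dict:
    req = urllib.request.Request(
        "https://circleci.com/api/v2/project/gh/ml-explore/mlx/pipeline",
        data=json.dumps({"branch": branch, "parameters": parameters}).encode(),
        headers={
            "Circle-Token": os.environ["CIRCLECI_TOKEN"],
            "Content-Type": "application/json",
        },
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

# Example: kick off the workflow gated on `nightly_build` above.
# trigger_pipeline({"nightly_build": True})
```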


@@ -1,24 +0,0 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}


@@ -1,38 +0,0 @@
name: 'Build Documentation'
description: 'Build documentation'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-linux
- name: Install dependencies
shell: bash
run: |
sudo apt-get install -y doxygen
source .venv/bin/activate
pip install -r docs/requirements.txt
pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: bash
run: tar -cf artifact.tar -C docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error


@@ -1,40 +0,0 @@
name: 'Build Linux wheel'
description: 'Build Linux wheel'
inputs:
build-backend:
description: 'Build the backend mlx-cpu package'
type: boolean
required: false
default: false
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Generate package stubs
shell: bash
run: |
pip install -e ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
- name: Build Python package
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
bash python/scripts/repair_linux.sh ${{ inputs.arch }}
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}


@@ -1,41 +0,0 @@
name: 'Build and Test on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
runs:
using: "composite"
steps:
- name: Install Python package
id: python_build
shell: sh
env:
DEBUG: 1
CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
# Can not build tests when the built executables can not run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
fi
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build CPP only
shell: bash
run: |
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake --build build -j $(nproc)


@@ -1,34 +0,0 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
build-backend:
description: 'Build the backend mlx-metal package'
type: boolean
required: false
default: false
runs:
using: "composite"
steps:
- name: Build Python package
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
pip install build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w


@@ -1,88 +0,0 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
runs:
using: "composite"
steps:
- name: Install dependencies
env:
DEBUG: 1
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Install tests dependencies
shell: bash -l {0}
run: |
pip install numpy torch tensorflow unittest-xml-reporting
- name: Run Python tests
shell: bash -l {0}
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -q "\[WARN\]" stderr.log; then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
shell: bash -l {0}
run: |
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext --inplace
python test.py
- name: Build CPP only
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
shell: bash -l {0}
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
shell: bash -l {0}
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit


@@ -1,87 +0,0 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
toolkit:
description: 'Which toolkit to install'
required: false
default: 'cpu'
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
max-size: 1GB
- name: Install common dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
- name: Setup Python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
pip install setuptools cmake nanobind==2.4.0
echo PATH=$PATH >> $GITHUB_ENV
# Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
- name: Install MPI
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Install CUDA toolkit
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
env:
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: |
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
- name: CUDA packages and driver report
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"


@@ -1,24 +0,0 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- uses: conda-incubator/setup-miniconda@v3
with:
miniconda-version: "latest"
python-version: ${{ inputs.python-version }}


@@ -1,69 +0,0 @@
name: 'Run Linux tests'
inputs:
has-gpu:
description: 'Run GPU tests'
required: false
default: false
runs:
using: "composite"
steps:
- name: Run MPI tests
shell: bash
run: |
echo "::group::MPI tests"
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
run: |
echo "::group::Distributed tests"
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::Python tests - CPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::Python tests - GPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run CPP tests - CPU
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::CPP tests - CPU"
./build/tests/tests
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::CPP tests - GPU"
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
echo "::endgroup::"


@@ -1,6 +0,0 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"


@@ -1,27 +0,0 @@
#!/bin/bash
set -ex
# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
blas-devel \
lapack-devel \
openblas-devel \
make \
cmake \
clang \
git
dnf clean all
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++
mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd


@@ -1,108 +0,0 @@
name: Build and Test
on:
pull_request:
push:
branches:
- main
# For testing CI without starting a pull request:
- test/*
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
cuda_build_and_test:
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
mac_build_and_test:
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
macos-target: ["14.0", "15.0"]
runs-on: [self-hosted, macos]
env:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh


@@ -1,28 +0,0 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4


@@ -1,96 +0,0 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12", "3.13", "3.14"]
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7

.github/workflows/pull_request.yml (new file, 20 lines, vendored)

@@ -0,0 +1,20 @@
on:
pull_request:
branches:
- main
jobs:
check_lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files


@@ -1,244 +0,0 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
inputs:
dev_release:
description: "Do a dev release or regular release"
required: true
default: "false"
permissions:
contents: read
jobs:
setup:
runs-on: ubuntu-latest
steps:
- name: Set publishing variables
run: echo "Publishing setup complete"
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_cuda_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
pattern: mlx-cpu-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/


@@ -1,10 +1,4 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-yaml
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:


@@ -19,17 +19,11 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with


@@ -26,7 +26,6 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -42,9 +41,7 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests -------------------------
message(
@@ -67,18 +64,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
endif()
else()
set(MLX_BUILD_METAL OFF)
endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
# ----------------------------- Lib -----------------------------
@@ -89,26 +78,22 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(METAL_LIB)
message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
@@ -117,8 +102,7 @@ if(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -128,12 +112,9 @@ if(MLX_BUILD_METAL)
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS >= 14.0")
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
execute_process(
@@ -142,6 +123,7 @@ if(MLX_BUILD_METAL)
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
@@ -149,12 +131,6 @@ if(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
@@ -182,7 +158,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate not found, using default backend.")
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
@@ -255,19 +231,12 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)


@@ -2,7 +2,7 @@
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
  also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
  [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
  the Python API. MLX has higher-level packages like `mlx.nn` and
  `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
  more complex models.
- **Composable function transformations**: MLX supports composable function
  transformations for automatic differentiation, automatic vectorization,
  and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
  materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
  dynamically. Changing the shapes of function arguments does not trigger
  slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
  (currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks
  is the *unified memory model*. Arrays in MLX live in shared memory.
  Operations on MLX arrays can be performed on any of the supported
  device types without transferring data.

MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.
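For reviewers skimming this compare, here is a minimal hand-written sketch (not taken from the diff) of how the lazy-computation, function-transformation, and unified-memory points above look in the Python API; the `loss` helper and the array shapes are illustrative assumptions.

```python
# Minimal sketch of MLX's lazy evaluation and composable transforms.
# Not part of this diff; names and shapes are illustrative only.
import mlx.core as mx

def loss(w, x):
    # Written like ordinary NumPy code; nothing is computed here yet.
    return mx.sum((x @ w) ** 2)

x = mx.random.uniform(shape=(4, 8))
w = mx.random.uniform(shape=(8, 2))

grad_fn = mx.grad(loss)   # composable transformation: d(loss)/d(w)
g = grad_fn(w, x)         # still lazy: only the graph is recorded

mx.eval(g)                # the gradient array is materialized here
print(g.shape)            # (8, 2)

# Unified memory: the same arrays are visible to both devices, so ops can
# target either one without an explicit transfer.
a = mx.add(x, 1.0, stream=mx.cpu)
b = mx.add(x, 1.0, stream=mx.gpu)
mx.eval(a, b)
```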
The design of MLX is inspired by frameworks like
[NumPy](https://numpy.org/doc/stable/index.html),
@@ -68,30 +68,25 @@ in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
```bash
**With `pip`**:
```
pip install mlx
```
To install the CUDA backend on Linux, run:
**With `conda`**:
```bash
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
conda install -c conda-forge mlx
```
Check out the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
@@ -110,7 +105,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTex entry:
```text
```
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},


@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream>
#include <sstream>
@@ -75,7 +74,7 @@ void time_irregular_binary_ops_3D() {
void time_irregular_binary_ops_4D() {
auto device = mx::default_device();
mx::Shape shape = {8, 8, 512, 512};
std::vector<int> shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
@@ -115,7 +114,7 @@ void time_irregular_binary_ops_4D() {
void time_irregular_reshape() {
auto device = mx::default_device();
mx::Shape shape;
std::vector<int> shape;
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
};
@@ -170,7 +169,7 @@ void time_irregular_astype_1D() {
void time_irregular_astype_2D() {
auto device = mx::default_device();
int size = 2048;
mx::Shape shape = {size, size};
std::vector<int> shape = {size, size};
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);


@@ -192,22 +192,6 @@ void time_reductions() {
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
}
void time_gather_scatter() {


@@ -142,7 +142,9 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -161,7 +163,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16", "complex64")
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
@@ -185,7 +187,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")


@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
import argparse
import os
import subprocess
import time
@@ -195,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True):
for dtype in ("float32", "float16", "complex64"):
for dtype in ("float32", "float16"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
@@ -214,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
)
)
plt.close(fig)


@@ -5,7 +5,6 @@ import os
import time
import torch
import torch.cuda
import torch.mps
@@ -45,10 +44,8 @@ def bench(f, *args):
def sync_if_needed(x):
if x.device == torch.device("mps"):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()
@torch.no_grad()
@@ -102,14 +99,6 @@ def reduction(op, axis, x):
sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
ys = []
@@ -351,11 +340,7 @@ if __name__ == "__main__":
args.axis.pop(0)
torch.set_num_threads(1)
device = "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
device = "cpu" if args.cpu else "mps"
types = args.dtype
if not types:
@@ -475,8 +460,5 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")


@@ -1,107 +0,0 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -1,7 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
@@ -20,63 +18,51 @@ def layer_norm(x, w, b, eps):
return y
def time_layer_norm(N, dt):
L = 1024
def time_layer_norm():
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(f, x, w, b):
for _ in range(32):
x = f(x, w, b)
return x
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
def layer_norm_grad_loop(g, x, w, b):
def layer_norm_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_grad_loop, g1, x, w, b)
time_fn(layer_norm_grad_loop, g2, x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
w = mx.random.uniform(shape=(N,)).astype(dt)
b = mx.random.uniform(shape=(N,)).astype(dt)
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_grad_x_loop(g, x):
def layer_norm_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(layer_norm_grad_x_loop, g1, x)
time_fn(layer_norm_grad_x_loop, g2, x)
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
time_fn(layer_norm_loop, g1, x)
time_fn(layer_norm_loop, g2, x)
time_fn(layer_norm_loop, mx.compile(g1), x)
time_fn(layer_norm_loop, mx.compile(g2), x)
if __name__ == "__main__":
for dt in [mx.float32, mx.float16, mx.bfloat16]:
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
print(dt, n)
time_layer_norm(n, dt)
time_layer_norm()
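# Sketch (not part of this diff): a quick check that the fused
# mx.fast.layer_norm agrees with a reference implementation like the one
# timed above. The helper below is an assumption for illustration only.
import mlx.core as mx

def layer_norm_ref(x, w, b, eps):
    mu = mx.mean(x, axis=-1, keepdims=True)
    var = mx.var(x, axis=-1, keepdims=True)
    return (x - mu) * mx.rsqrt(var + eps) * w + b

x = mx.random.uniform(shape=(2, 64, 1024))
w = mx.random.uniform(shape=(1024,))
b = mx.random.uniform(shape=(1024,))
assert mx.allclose(
    layer_norm_ref(x, w, b, 1e-5),
    mx.fast.layer_norm(x, w, b, 1e-5),
    atol=1e-4,
)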

View File

@@ -1,212 +0,0 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter
RESULTS_DIR = "./results"
if not os.path.isdir(RESULTS_DIR):
os.mkdir(RESULTS_DIR)
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
TORCH_DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else ("cuda" if torch.cuda.is_available() else "cpu")
)
N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")
def _power_of_two_formatter(value, _position):
if value <= 0:
return ""
exponent = int(round(math.log2(value)))
if abs(value - (1 << exponent)) / value > 1e-6:
return f"{value:g}"
return f"$2^{{{exponent}}}$"
def torch_sync():
if TORCH_DEVICE.type == "cuda":
torch.cuda.synchronize()
elif TORCH_DEVICE.type == "mps":
torch.mps.synchronize()
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
outs = []
for _ in range(N_ITER_FUNC):
out = copy(self_arr)
out[mask_arr] = src_arr
outs.append(out)
mx.eval(outs)
return outs
@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
outs = []
for _ in range(N_ITER_FUNC):
out = self_tensor.clone()
out.masked_scatter_(mask_tensor, src_tensor)
outs.append(out)
torch_sync()
return outs
def measure(fn):
for _ in range(N_WARMUP):
fn()
start = time.perf_counter_ns()
for _ in range(N_ITER_BENCH):
fn()
end = time.perf_counter_ns()
return (end - start) * 1e-9
def bytes_touched(length, true_count, item_size):
mask_bytes = length
self_bytes = length * item_size * 2 # read + write
src_bytes = true_count * item_size
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
def build_case(length, density, np_dtype, torch_dtype):
true_count = max(1, int(round(length * density)))
rng = np.random.default_rng()
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
mask_np = np.zeros(length, dtype=bool)
mask_np[:true_count] = True
rng.shuffle(mask_np)
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
self_mlx = mx.array(self_np)
mask_mlx = mx.array(mask_np)
src_mlx = mx.array(src_np)
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
# Correctness check once per configuration
mx_out = mx.array(self_np)
mx_out[mask_mlx] = src_mlx
mx.eval(mx_out)
torch_out = self_torch.clone()
torch_out.masked_scatter_(mask_torch, src_torch)
atol = 5e-3 if np_dtype == np.float16 else 1e-5
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
raise AssertionError("masked_scatter results diverged between MLX and Torch")
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
def bench_case(length, density, dtype):
np_dtype = getattr(np, dtype)
torch_dtype = getattr(torch, dtype)
(
self_mlx,
mask_mlx,
src_mlx,
self_torch,
mask_torch,
src_torch,
true_count,
) = build_case(length, density, np_dtype, torch_dtype)
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
time_torch = measure(
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
)
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
bytes_per_gb = float(1024**3)
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
return time_mlx, time_torch, mlx_gbps, torch_gbps
def plot_density(ax_perf, ax_speedup, density, dtype):
mlx_gbps = []
torch_gbps = []
mlx_times = []
torch_times = []
for length in VECTOR_LENGTHS:
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
mlx_gbps.append(gbps_mlx)
torch_gbps.append(gbps_torch)
mlx_times.append(t_mlx)
torch_times.append(t_torch)
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
ax_perf.set_xscale("log", base=2)
ax_perf.set_xticks(VECTOR_LENGTHS)
formatter = FuncFormatter(_power_of_two_formatter)
ax_perf.xaxis.set_major_formatter(formatter)
ax_perf.set_title(f"density={density:.2f}")
ax_perf.set_ylabel("GB/s")
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
ax_perf.legend()
speedup = np.array(torch_times) / np.array(mlx_times)
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
ax_speedup.set_xscale("log", base=2)
ax_speedup.set_xticks(VECTOR_LENGTHS)
ax_speedup.xaxis.set_major_formatter(formatter)
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
def main():
for dtype in D_TYPES:
fig, axs = plt.subplots(
len(MASK_DENSITIES),
2,
figsize=(10, 12),
layout="constrained",
sharex=True,
)
for i, density in enumerate(MASK_DENSITIES):
plot_density(axs[i][0], axs[i][1], density, dtype)
axs[i][0].set_xlabel("vector length")
axs[i][1].set_xlabel("vector length")
fig.suptitle(
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
)
output_path = os.path.join(
RESULTS_DIR,
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
)
fig.savefig(output_path)
plt.close(fig)
if __name__ == "__main__":
main()
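# Sketch (not part of the deleted file): the core operation benchmarked above.
# Boolean-mask assignment in MLX fills the True positions of `out` with
# consecutive elements of `src`, mirroring torch.Tensor.masked_scatter_.
import mlx.core as mx

out = mx.zeros(8)
mask = mx.array([True, False, True, False, True, False, False, True])
src = mx.array([1.0, 2.0, 3.0, 4.0])
out[mask] = src
mx.eval(out)
print(out)  # expected: array([1, 0, 2, 0, 3, 0, 0, 4], dtype=float32)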

View File

@@ -51,20 +51,6 @@ def time_maximum():
time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -122,8 +108,6 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_min()
time_max()
time_maximum()
time_exp()
time_negative()

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -1,3 +0,0 @@
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.

View File

@@ -11,14 +11,13 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
# files (like headers)
#
# clang format on
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -27,10 +26,6 @@ macro(mlx_build_metallib)
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Prepare metallib build command
add_custom_command(

View File

@@ -1,5 +1,4 @@
sphinx
breathe
sphinx-book-theme
sphinx-copybutton
mlx

Binary file not shown (before: 16 KiB).

Binary file not shown (before: 22 KiB).

View File

@@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, Apple"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
release = version
@@ -18,7 +18,6 @@ release = version
# -- General configuration ---------------------------------------------------
extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",

View File

@@ -8,26 +8,23 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
@@ -42,13 +39,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note::
Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
We are only required to pass the body of the Metal kernel in ``source``.
The full function signature will be generated using:
@@ -86,52 +78,44 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
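For instance, assuming the ``myexp`` kernel and the input array ``a`` from the
example above are in scope, a call that prints the generated source could look
like this (a sketch):

.. code-block:: python

    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        verbose=True,  # print the generated Metal source
    )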
Using Shape/Strides
-------------------
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source,
ensure_row_contiguous=False,
)
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
@@ -139,6 +123,7 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]
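A quick way to exercise the strided path (a sketch, assuming the
``exp_elementwise`` wrapper above): call it on a non-contiguous input, for
example a transposed array, and compare against ``mx.exp``.

.. code-block:: python

    a = mx.random.normal(shape=(64, 128)).T  # not row contiguous
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))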
@@ -157,139 +142,137 @@ We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
return output
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
@mx.custom_function
def grid_sample(x, grid):
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
int ix_nw = floor(ix);
int iy_nw = floor(iy);
assert D == 2, "Last dim of `grid` must be size 2."
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
@mx.custom_function
def grid_sample(x, grid):
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
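As a quick sanity check (a sketch, assuming both ``grid_sample_ref`` and the
``grid_sample`` kernel above are defined), the two implementations should
agree for grids that stay inside the input:

.. code-block:: python

    x = mx.random.normal(shape=(2, 16, 16, 8))
    grid = mx.random.uniform(low=-0.9, high=0.9, shape=(2, 8, 8, 2))
    assert mx.allclose(grid_sample(x, grid), grid_sample_ref(x, grid), atol=1e-5)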
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
@@ -298,11 +281,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features:
requires a few extra ``mx.fast.metal_kernel`` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -316,129 +299,128 @@ We can then implement the backwards pass as follows:
.. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
int gH = grid_shape[1];
int gW = grid_shape[2];
assert D == 2, "Last dim of `grid` must be size 2."
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int gH = grid_shape[1];
int gW = grid_shape[2];
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T gix_mult = W / 2;
T giy_mult = H / 2;
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
T gix_mult = W / 2;
T giy_mult = H / 2;
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
assert D == 2, "Last dim of `grid` must be size 2."
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:

View File

@@ -138,13 +138,13 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** The name of primitive. */
const char* name() const override {
return "Axpby";
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/
@@ -394,14 +394,14 @@ below.
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::string kname;
kname = "axpby_general_" + type_to_name(out);
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -70,7 +70,6 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/cuda
python/memory_management
python/nn
python/optimizers

View File

@@ -13,51 +13,22 @@ silicon computer is
pip install mlx
To install from PyPI your system must meet the following requirements:
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.10
- macOS >= 14.0
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
MLX is only available on devices running macOS >= 14.0.
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
CUDA
^^^^
MLX has a CUDA backend which you can install with:
MLX is also available on conda-forge. To install MLX with conda do:
.. code-block:: shell
pip install mlx[cuda12]
To install the CUDA package from PyPI your system must meet the following
requirements:
- Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPI your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.10
conda install conda-forge::mlx
Troubleshooting
@@ -94,8 +65,6 @@ Build Requirements
Python API
^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
@@ -107,20 +76,20 @@ Then simply build and install MLX using pip:
.. code-block:: shell
pip install .
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
pip install -e ".[dev]"
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
Run the tests with:
@@ -138,8 +107,6 @@ IDE:
C++ API
^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start
@@ -218,7 +185,6 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
@@ -247,50 +213,6 @@ be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting
^^^^^^^^^^^^^^^

View File

@@ -19,8 +19,6 @@ Array
array.ndim
array.shape
array.size
array.real
array.imag
array.abs
array.all
array.any

View File

@@ -1,9 +0,0 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,4 +13,3 @@ Fast
rope
scaled_dot_product_attention
metal_kernel
cuda_kernel

View File

@@ -16,8 +16,6 @@ Linear Algebra
cross
qr
svd
eigvals
eig
eigvalsh
eigh
lu

View File

@@ -27,7 +27,6 @@ simple functions.
mish
prelu
relu
relu2
relu6
selu
sigmoid

View File

@@ -50,7 +50,6 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU2
ReLU6
RNN
RoPE

View File

@@ -112,7 +112,6 @@ Operations
max
maximum
mean
median
meshgrid
min
minimum

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors"))
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuration parameter is saved in the state. For

View File

@@ -19,4 +19,3 @@ Common Optimizers
Adamax
Lion
MultiOptimizer
Muon

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x)
timeit(mx.compile(gelu), x)
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
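The full pattern (a sketch consistent with the snippet above) passes the state
to ``mx.compile`` through its ``inputs`` and ``outputs`` arguments so the
function body does not need to return it:

.. code-block:: python

    from functools import partial

    state = []

    @partial(mx.compile, inputs=state, outputs=state)
    def fun(x, y):
        z = x + y
        state.append(z)
        return mx.exp(z)

    fun(mx.array(1.0), mx.array(2.0))
    print(state)  # Prints [array(3, dtype=float32)]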

View File

@@ -7,29 +7,21 @@ Distributed Communication
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support several different communication backends introduced below.
.. list-table::
:widths: 20 80
:header-rows: 1
* - Backend
- Description
* - :ref:`MPI <mpi_section>`
- A full featured and mature distributed communications library.
* - :ref:`RING <ring_section>`
- Ring all reduce and all gather over TCP sockets. Always available and
usually faster than MPI.
* - :ref:`JACCL <ring_section>`
- Low latency communication with RDMA over thunderbolt. Necessary for
things like tensor parallelism.
* - :ref:`NCCL <nccl_section>`
- The backend of choice for CUDA environments.
moment we support two different communication backends:
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
.. note::
Some operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
@@ -92,8 +84,9 @@ Selecting Backend
^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
available backends. If they all fail then a singleton group is created.
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
.. note::
After a distributed backend is successfully initialized :func:`init` will
@@ -117,8 +110,6 @@ The following examples aim to clarify the backend initialization logic in MLX:
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init() # same as MPI because it was initialized first!
.. _training_example:
Training Example
----------------
@@ -193,7 +184,7 @@ almost identical to the example above:
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mx.nn.average_gradients(grads) # <---- This line was added
grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss
@@ -201,7 +192,80 @@ almost identical to the example above:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
.. _ring_section:
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
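The minimal example referred to above (defined earlier on the page, not shown
in this diff) is essentially an all-reduce of an array of ones, along these
lines:

.. code:: python

    import mlx.core as mx

    world = mx.distributed.init()
    x = mx.distributed.all_sum(mx.ones(10))
    print(world.rank(), x)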
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
Getting Started with Ring
-------------------------
@@ -252,13 +316,22 @@ utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
To validate your connection without configuring anything
``mlx.distributed_config`` can also plot the ring using DOT format.
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
@@ -269,382 +342,3 @@ If you want to go through the process manually, the steps are as follows:
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
.. _jaccl_section:
Getting Started with RDMA over Thunderbolt
------------------------------------------
Starting from version 26.2, RDMA over thunderbolt is available in macOS and
enables low-latency communication between Macs with thunderbolt 5. MLX provides
the JACCL backend that uses this functionality to achieve communication latency
an order of magnitude lower than the ring backend.
.. note::
The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
Communication Library* and it is an obvious pun to Nvidia's NCCL but also
tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt
at Apple.
Enabling RDMA
^^^^^^^^^^^^^
Until the feature matures, enabling RDMA over thunderbolt is slightly more
involved and **cannot** be done remotely even with sudo. In fact, it has to be
done in macOS recovery:
1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
2. Open the Terminal by going to Utilities -> Terminal.
3. Run ``rdma_ctl enable``.
4. Reboot.
To verify that you have successfully enabled Thunderbolt RDMA you can run
``ibv_devices`` which should produce something like the following for an M3 Ultra.
.. code-block:: bash
~ % ibv_devices
device node GUID
------ ----------------
rdma_en2 8096a9d9edbaac05
rdma_en3 8196a9d9edbaac05
rdma_en5 8396a9d9edbaac05
rdma_en4 8296a9d9edbaac05
rdma_en6 8496a9d9edbaac05
rdma_en7 8596a9d9edbaac05
Defining a Mesh
^^^^^^^^^^^^^^^
The JACCL backend supports only fully connected topologies. Namely, there needs
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
the following topology visualizations, the left one is valid because there is a
connection from any node to any other node, while for the one on the right M3
Ultra 1 is not connected to M3 Ultra 2.
.. raw:: html
<div style="display: flex; text-align: center; align-items: end; font-size: 80%;">
<div>
<img src="/_static/distributed/m3-ultra-mesh.png" alt="M3 Ultra thunderbolt mesh" style="width: 55%">
<p>Fully connected mesh of four M3 Ultra.</p>
</div>
<div>
<img src="/_static/distributed/m3-ultra-mesh-broken.png" alt="M3 Ultra broken thunderbolt mesh" style="width: 55%">
<p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>
</div>
</div>
Similar to the ring backend, the easiest way to use JACCL with MLX is to write
a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain
- Hostnames to use for launching scripts via ssh
- An IP for rank 0 that is reachable by all nodes
- A list of rdma devices that connect each node to each other node
The following JSON defines the valid 4-node mesh from the image above.
.. code-block:: json
[
{
"ssh": "m3-ultra-1",
"ips": ["123.123.123.1"],
"rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
},
{
"ssh": "m3-ultra-2",
"ips": [],
"rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"]
},
{
"ssh": "m3-ultra-3",
"ips": [],
"rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"]
},
{
"ssh": "m3-ultra-4",
"ips": [],
"rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null]
}
]
Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
disabling the thunderbolt bridge is still required as well as setting up
isolated local networks for each thunderbolt connection.
All of the above can be done instead via ``mlx.distributed_config``. This helper
script will
- ssh into each node
- extract the thunderbolt connectivity
- check for a valid mesh
- provide the commands to configure each node (or run them if sudo is available)
- generate the hostfile to be used with ``mlx.launch``
Putting it All Together
^^^^^^^^^^^^^^^^^^^^^^^^
For example launching a distributed MLX script that uses JACCL is fairly simple
if the nodes are reachable via ssh and have password-less sudo.
First, connect all the thunderbolt cables. Then we can verify the connections
by using the ``mlx.distributed_config`` script to visualize them.
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --dot | dot -Tpng | open -f -a Preview
After making sure that everything looks right we can auto-configure the nodes
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --backend jaccl \
--auto-setup --output m3-ultra-jaccl.json
And now we are ready to run a distributed MLX script such as distributed inference
of a gigantic model using MLX-LM.
.. code-block::
mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
--env MLX_METAL_FAST_SYNCH=1 -- \ # <--- important
/path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-V3.2-8bit --shard
.. note::
Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a
different, faster way of synchronizing between the GPU and the CPU. It is
not specific to the JACCL backend and can be used in all cases where the CPU
and GPU need to collaborate for some computation and is pretty critical for
low-latency communication since the communication is done by the CPU.
.. _nccl_section:
Getting Started with NCCL
-------------------------
MLX on CUDA environments ships with the ability to talk to `NCCL
<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
communication library that supports both multi-gpu and multi-node setups.
For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
it takes to run a distributed job is
.. code-block::
mlx.launch -n 8 test.py
# perfect for interactive scripts
mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
with the same ease
.. code-block::
mlx.launch --hosts my-cuda-node -n 8 test.py
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
because the cluster scheduler will be the one launching the processes. You can
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
order for the MLX NCCL backend to be initialized correctly.
.. _mpi_section:
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to `MPI
<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed
on the machine. Launching distributed MLX programs that use MPI can be done
with ``mpirun`` as expected. However, in the following examples we will be
using ``mlx.launch --backend mpi`` which takes care of some nuisances such as
setting absolute paths for the ``mpirun`` executable and the ``libmpi.dyld``
shared library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``. Some environments use a non-standard
library filename that can be specified using the ``MPI_LIBNAME`` environment
variable. This is automatically taken care of by ``mlx.launch`` as well.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
.. _no_mlx_launch:
Distributed Without ``mlx.launch``
----------------------------------
None of the implementations of the distributed backends require launching with
``mlx.launch``. The script simply connects to each host, starts a process per
rank, and sets up the necessary environment variables before delegating to your
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
for more details.
For many use-cases this will be the easiest way to perform distributed
computations in MLX. However, there may be reasons that you cannot or should
not use ``mlx.launch``. A common such case is the use of a scheduler that
starts all the processes for you on machines undetermined at the time of
scheduling the job.
Below we list the environment variables required to use each backend.
Ring
^^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
ports for each rank to listen to, something like the following:
.. code-block:: json
[
["123.123.1.1:5000", "123.123.1.2:5000"],
["123.123.2.1:5000", "123.123.2.2:5000"],
["123.123.3.1:5000", "123.123.3.2:5000"],
["123.123.4.1:5000", "123.123.4.2:5000"]
]
**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
from the distributed backend.
JACCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 listens on
and that all the other ranks connect to in order to establish the RDMA connections.
**MLX_IBV_DEVICES** should contain the path to a json file that contains the
ibverbs device names that connect each node to each other node, something like
the following:
.. code-block:: json
[
[null, "rdma_en5", "rdma_en4", "rdma_en3"],
["rdma_en5", null, "rdma_en3", "rdma_en4"],
["rdma_en4", "rdma_en3", null, "rdma_en5"],
["rdma_en3", "rdma_en4", "rdma_en5", null]
]
NCCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_WORLD_SIZE** should contain the total number of processes that will be
launched.
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
hosts can connect to to establish the NCCL communication.
**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that
corresponds to this process.
Of course any `other environment variable
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
used by NCCL can be set.
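For instance, the second process on the first node of a two-node job with
eight GPUs per node could be started with something like the following (the
address, port, and script are placeholders):

.. code-block:: shell

    MLX_RANK=1 MLX_WORLD_SIZE=16 \
    NCCL_HOST_IP=123.123.1.1 NCCL_PORT=7777 \
    CUDA_VISIBLE_DEVICES=1 \
    python train.py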
.. _tips_and_tricks:
Tips and Tricks
----------------
This is a small collection of tips to help you make better use of the
distributed communication capabilities of MLX.
- *Test locally first.*
  You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a
  small-scale test on a single node before moving to multiple machines.
- *Batch your communication.*
  As described in the :ref:`training example <training_example>`, performing
  many small communications can hurt performance. Copy the approach of
  :func:`mlx.nn.average_gradients` to gather many small communications into a
  single large one, as in the sketch after this list.
- *Visualize the connectivity.*
  Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
  visualize the connections and make sure that the cables are connected
  correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
- *Use the debugger.*
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
processes and gathers stdout from all processes. This makes using ``pdb`` a
breeze.
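The following is a minimal sketch of the batching idea using
``mx.distributed.all_sum``. It illustrates the approach rather than the actual
implementation of :func:`mlx.nn.average_gradients`, and it assumes all
gradients share the same dtype:

.. code-block:: python

    import mlx.core as mx

    def all_average_batched(grads):
        """Average a list of arrays across ranks with a single all_sum."""
        group = mx.distributed.init()
        # Concatenate the flattened gradients into one buffer so that only a
        # single communication is performed instead of one per array.
        flat = mx.concatenate([g.reshape(-1) for g in grads])
        flat = mx.distributed.all_sum(flat) / group.size()
        # Split the buffer back into arrays with the original shapes.
        out, offset = [], 0
        for g in grads:
            out.append(flat[offset:offset + g.size].reshape(g.shape))
            offset += g.size
        return out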

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters())``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = tree_flatten(model.parameters(), destination={})
params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -164,13 +164,13 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
out, = imported_abs(mx.array([-1.0]))
# Also ok
out, = imported_abs(mx.array(-1.0))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -70,8 +70,7 @@ Differences from NumPy
* Indexing does not perform bounds checking. Indexing out of bounds is
undefined behavior.
* Boolean mask based indexing is supported for assignment only (see
:ref:`boolean-mask-assignment`).
* Boolean mask based indexing is not yet supported.
The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
@@ -108,28 +107,6 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
Transformations of functions which use in-place updates are allowed and work as
expected. For example:
@@ -144,51 +121,3 @@ expected. For example:
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.
.. _boolean-mask-assignment:
Boolean Mask Assignment
-----------------------
MLX supports boolean indices using NumPy syntax. A mask must already be
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
Other index types are routed through the standard scatter code.
.. code-block:: shell
>>> a = mx.array([1.0, 2.0, 3.0])
>>> mask = mx.array([True, False, True])
>>> updates = mx.array([5.0, 6.0])
>>> a[mask] = updates
>>> a
array([5.0, 2.0, 6.0], dtype=float32)
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
assignments, ``updates`` must provide at least as many elements as there are
``True`` entries in ``mask``.
.. code-block:: shell
>>> a = mx.zeros((2, 3))
>>> mask = mx.array([[True, False, True],
[False, False, True]])
>>> a[mask] = 1.0
>>> a
array([[1.0, 0.0, 1.0],
[0.0, 0.0, 1.0]], dtype=float32)
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. The only
exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full.
.. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
axes and therefore raise errors.

View File

@@ -7,106 +7,13 @@ Launching Distributed Programs
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides two utilities to help you configure
your Macs for distributed computation and also launch distributed programs on
multiple nodes or with many processes in a single node. These utilities are aptly named
Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.
- ``mlx.launch``
- ``mlx.distributed_config``
See the :doc:`distributed docs <distributed>` for an introduction and
getting-started guides to the various backends.
``mlx.distributed_config``
---------------------------
Unless you are launching distributed jobs locally for development or multi-gpu
CUDA environments, then you have several Macs that you need to configure for
distributed communication with MLX.
``mlx.distributed_config`` aims to automate the process of configuring the
network interfaces (especially for communication over thunderbolt) and also
creating the hostfile to be used with ``mlx.launch``.
We will analyse 3 cases of using ``mlx.distributed_config``
1. RDMA over thunderbolt using JACCL
2. TCP/IP over thunderbolt using the ring backend
3. TCP/IP over ethernet using the ring backend
JACCL
^^^^^^^
After following :ref:`the steps to enable RDMA <jaccl_section>` you can run the
following command to configure the nodes and create the hostfile.
.. code-block::
mlx.distributed_config --verbose --backend jaccl \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \
--auto-setup --output m3-ultra-jaccl.json
Let's walk through the steps that the script takes to configure the nodes.
1. Ssh to all nodes to verify that they are reachable
2. Extract the thunderbolt connectivity. Namely run commands on each node to
calculate which node is connected to which other node.
3. Verify that we have a valid fully connected mesh
4. Check that RDMA is enabled
5. Extract the ethernet IP from interface en0
6. Disable the thunderbolt bridge and set up peer to peer networks for each
thunderbolt cable
7. Write the hostfile
Knowing the above steps allows you to manually configure the nodes but also
debug any configuration issue. For instance changing the Ethernet IP to a
different interface directly in the config is possible (as long as it is
reachable from all nodes).
The ``--auto-setup`` argument requires password-less sudo on each node. If it
isn't available then the configuration script will print commands to be run on
each node.
Ring over thunderbolt
^^^^^^^^^^^^^^^^^^^^^
Setting up a ring backend over thunderbolt only requires changing the
``--backend`` from ``jaccl`` to ``ring``.
The steps are very similar with the main difference being that instead of
verifying that the nodes are fully connected, the script attempts to identify a
ring topology (or multiple rings).
Ring over Ethernet
^^^^^^^^^^^^^^^^^^
Configuring the ring backend over ethernet doesn't require setting up network
interface and as such it simply extracts the ``en0`` IP from each node and
writes the hostfile.
Debugging cable connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^
``mlx.distributed_config`` can help you debug the connectivity of your nodes
over thunderbolt by exporting a graph of the connections.
Running
.. code-block::
mlx.distributed_config --verbose \
--hosts host1,host2,host3,host4 \
--over thunderbolt --dot
will export a `GraphViz <https://graphviz.org>`_ representation of the
connections between the nodes which makes it very easy to figure out which
cable is not connected correctly.
See :ref:`the JACCL section <jaccl_section>` for an example.
``mlx.launch``
--------------
Usage
-----
The minimal usage example of ``mlx.launch`` is simply
@@ -126,10 +33,6 @@ the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.
Importantly, it also broadcasts stdin to each process which enables interactive
programs to work in distributed mode as well as debugging using the interactive
debugger.
Providing Hosts
^^^^^^^^^^^^^^^^
@@ -160,62 +63,10 @@ host and on the same path. A good checklist to debug errors is the following:
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
If you are launching from a node with a completely different setup than the
nodes that the program will run on, you can specify ``--no-verify-script`` so
that ``mlx.launch`` does not attempt to verify that the executable and script
exist locally before launching the distributed job.
.. _ring_specifics:
Ring Specifics
^^^^^^^^^^^^^^
The :ref:`ring <ring_section>` backend, which is also the default
backend, can be explicitly selected with the argument ``--backend ring``. The
ring backend has some specific requirements and arguments that are different to
other backends:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.
.. _jaccl_specifics:
JACCL Specifics
^^^^^^^^^^^^^^^^
The :ref:`JACCL <jaccl_section>` backend can be selected with the argument
``--backend jaccl``. A hostfile is necessary to launch with this backend
because it needs to contain the RDMA devices connecting each node to each other
node.
NCCL Specifics
^^^^^^^^^^^^^^
The :ref:`NCCL <nccl_section>` backend is the default backend for CUDA
environments. When launching from a Mac to a Linux machine with CUDA then the
backend should be selected using ``--backend nccl``.
The ``--repeat-hosts, -n`` argument should be used to launch multi-node and
multi-gpu jobs. For instance
.. code-block::
mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh
will attempt to launch 16 processes, 8 on each node that will all run
``my-job.sh``.
.. _mpi_specifics:
MPI Specifics
^^^^^^^^^^^^^
-------------
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
@@ -232,3 +83,23 @@ to choose a specific interface for the byte-transfer-layer of MPI we can call
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. _ring_specifics:
Ring Specifics
--------------
The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different to MPI:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2025 Apple Inc.
#include <dlfcn.h>
#include <iostream>
#include <sstream>
@@ -17,19 +16,6 @@
namespace my_ext {
// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
static std::string binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path().string();
}();
return binary_dir;
}
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -181,15 +167,16 @@ void Axpby::eval_gpu(
}
// Resolve name of kernel (corresponds to axpby.metal)
std::string kname = "axpby_";
kname += (contiguous_kernel ? "contiguous_" : "general_");
kname += type_to_name(out);
std::ostringstream kname;
kname << "axpby_";
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** The name of primitive. */
const char* name() const override {
return "Axpby";
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.4.0
nanobind==2.2.0

View File

@@ -3,10 +3,8 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c_cpu.shape}")
print(f"c dtype: {c_cpu.dtype}")
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -1,6 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
@@ -20,7 +21,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
@@ -54,9 +55,6 @@ endif()
if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)

24
mlx/allocator.cpp Normal file
View File

@@ -0,0 +1,24 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
explicit Buffer(void* ptr) : ptr_(ptr) {};
Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();
@@ -28,16 +28,16 @@ class Buffer {
};
};
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default;
Allocator(const Allocator& other) = delete;
@@ -49,25 +49,4 @@ class Allocator {
Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
return allocator().make_buffer(ptr, size);
};
// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->offset = other.array_desc_->offset;
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy;
}
@@ -82,28 +82,6 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin());
}
array::array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
auto buffer = allocator::make_buffer(data, nbytes());
if (buffer.ptr() == nullptr) {
set_data(allocator::malloc(nbytes()));
auto ptr = static_cast<char*>(data);
std::copy(ptr, ptr + nbytes(), this->data<char>());
deleter(data);
} else {
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
auto ptr = buffer.ptr();
allocator::release(buffer);
return deleter(ptr);
};
set_data(buffer, std::move(wrapped_deleter));
}
}
/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
@@ -163,7 +141,7 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->offset = 0;
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
@@ -178,7 +156,7 @@ void array::set_data(
Flags flags,
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->offset = 0;
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
@@ -189,13 +167,14 @@ void array::copy_shared_buffer(
const Strides& strides,
Flags flags,
size_t data_size,
int64_t offset /* = 0 */) {
size_t offset /* = 0 */) {
array_desc_->data = other.array_desc_->data;
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
array_desc_->offset =
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
}
void array::copy_shared_buffer(const array& other) {
@@ -262,8 +241,8 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
init();
}

View File

@@ -10,7 +10,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"
namespace mlx::core {
@@ -19,8 +18,8 @@ class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -57,16 +56,6 @@ class array {
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */
explicit array(
allocator::Buffer data,
@@ -235,10 +224,6 @@ class array {
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
Data(Data&& o) : buffer(o.buffer), d(o.d) {
o.buffer = allocator::Buffer(nullptr);
o.d = [](allocator::Buffer) {};
}
~Data() {
d(buffer);
}
@@ -304,11 +289,6 @@ class array {
return array_desc_->siblings;
}
/** The array's position in the sibling list. */
int sibling_position() const {
return array_desc_->position;
}
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
@@ -364,23 +344,15 @@ class array {
return array_desc_->data;
}
// Return a raw pointer to the arrays data. This function may do a copy if
// the underlying buffer is not accessible on the CPU. When accessing the
// data for GPU kernels, be sure to use the correct method / function for the
// given backend to access the GPU pointer.
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
return reinterpret_cast<T*>(
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
return static_cast<T*>(array_desc_->data_ptr);
}
template <typename T>
const T* data() const {
return const_cast<array&>(*this).data<T>();
}
int64_t offset() const {
return array_desc_->offset;
return static_cast<T*>(array_desc_->data_ptr);
}
enum Status {
@@ -449,7 +421,7 @@ class array {
const Strides& strides,
Flags flags,
size_t data_size,
int64_t offset = 0);
size_t offset = 0);
void copy_shared_buffer(const array& other);
@@ -484,8 +456,8 @@ class array {
// can share the underlying data buffer.
std::shared_ptr<Data> data;
// Offset from beginning of data pointer
int64_t offset{0};
// Properly offset data pointer
void* data_ptr{nullptr};
// The size in elements of the data buffer the array accesses
size_t data_size;

View File

@@ -6,4 +6,5 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transpose.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)

View File

@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
BinaryOpType bopt) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
out.set_data(
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
mallocfn(b.data_size() * out.itemsize()),
allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
mallocfn(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
mallocfn(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(mallocfn(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(allocator::malloc(0));
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);

View File

@@ -1,157 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cassert>
#include <functional>
#include <map>
namespace mlx::core {
template <typename T>
class BufferCache {
public:
BufferCache(
size_t page_size,
std::function<size_t(T*)> get_size,
std::function<void(T*)> free)
: page_size_(page_size),
get_size_(std::move(get_size)),
free_(std::move(free)) {}
~BufferCache() {
clear();
}
BufferCache(const BufferCache&) = delete;
BufferCache& operator=(const BufferCache&) = delete;
T* reuse_from_cache(size_t size) {
// Find the closest buffer in pool.
auto it = buffer_pool_.lower_bound(size);
if (it == buffer_pool_.end() ||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
return nullptr;
}
// Collect from the cache.
T* buf = it->second->buf;
pool_size_ -= it->first;
// Remove from record.
remove_from_list(it->second);
buffer_pool_.erase(it);
return buf;
}
void recycle_to_cache(T* buf) {
assert(buf);
// Add to cache.
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
size_t size = get_size_(buf);
pool_size_ += size;
buffer_pool_.emplace(size, bh);
}
int release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
// Release buffer.
size_t size = get_size_(tail_->buf);
total_bytes_freed += size;
free_(tail_->buf);
n_release++;
// Remove from record.
auto its = buffer_pool_.equal_range(size);
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
return el.second == tail_;
});
assert(it != buffer_pool_.end());
buffer_pool_.erase(it);
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
int clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
free_(holder->buf);
n_release++;
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
size_t cache_size() const {
return pool_size_;
}
size_t page_size() const {
return page_size_;
}
private:
struct BufferHolder {
public:
explicit BufferHolder(T* buf_) : buf(buf_) {}
BufferHolder* prev{nullptr};
BufferHolder* next{nullptr};
T* buf;
};
void add_at_head(BufferHolder* to_add) {
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void remove_from_list(BufferHolder* to_remove) {
if (to_remove->prev && to_remove->next) { // if middle
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // if tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // if head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // if only element
head_ = nullptr;
tail_ = nullptr;
}
delete to_remove;
}
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_{nullptr};
BufferHolder* tail_{nullptr};
size_t pool_size_{0};
const size_t page_size_;
std::function<size_t(T*)> get_size_;
std::function<void(T*)> free_;
};
} // namespace mlx::core

View File

@@ -2,6 +2,7 @@
#include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/transpose.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -19,26 +20,19 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
// Calculate the contiguity based on the given shape and strides
auto [ds, rc, cc] = check_contiguity(shape_, strides_);
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
flags.contiguous = rc || cc;
flags.row_contiguous = rc;
flags.col_contiguous = cc;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
// There is no easy way to compute the actual data size so we use out.size()
// when the array is not contiguous.
size_t data_size = flags.contiguous ? ds : out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
@@ -270,36 +264,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
Strides out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
int64_t f_stride = 1;
int64_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
transpose(inputs[0], out, axes_);
}
} // namespace mlx::core

View File

@@ -1,7 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -14,8 +15,6 @@ void print_constant(std::ostream& os, const array& x) {
return print_float_constant<float16_t>(os, x);
case bfloat16:
return print_float_constant<bfloat16_t>(os, x);
case float64:
return print_float_constant<double>(os, x);
case complex64:
return print_complex_constant<complex64_t>(os, x);
case int8:
@@ -52,8 +51,6 @@ std::string get_type_string(Dtype d) {
return "float16_t";
case bfloat16:
return "bfloat16_t";
case float64:
return "double";
case complex64:
return "complex64_t";
case bool_:
@@ -82,6 +79,55 @@ std::string get_type_string(Dtype d) {
}
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const Shape& shape) {
@@ -113,10 +159,9 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous,
const std::function<allocator::Buffer(size_t)>&
mallocfn /* = allocator::malloc */) {
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous) {
if (contiguous) {
int o = 0;
Strides strides;
@@ -130,7 +175,8 @@ void compiled_allocate_outputs(
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && is_constant(i)) {
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
@@ -142,7 +188,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
mallocfn(data_size * outputs[o].itemsize()),
allocator::malloc(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -158,86 +204,16 @@ void compiled_allocate_outputs(
// - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
is_constant(i)) {
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(mallocfn(outputs[o].nbytes()));
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
}
}
}
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant) {
const Shape& shape = out.shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
if (contiguous) {
return {true, shape, {}};
}
std::vector<Strides> strides_vec{out.strides()};
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants.
if (is_constant(i)) {
continue;
}
// Skip scalar inputs.
const auto& x = inputs[i];
if (is_scalar(x)) {
continue;
}
// Broadcast the inputs to the output shape.
Strides xstrides;
size_t j = 0;
for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides_vec.push_back(std::move(xstrides));
}
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
}
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous) {
if (contiguous) {
size_t max_size = 0;
for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
return max_size > UINT32_MAX;
} else {
size_t max_size = 0;
for (const auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
return max_size > UINT32_MAX;
}
}
} // namespace mlx::core

View File

@@ -1,8 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <functional>
#include <iomanip>
#include <sstream>
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/primitives.h"
@@ -13,17 +14,19 @@ inline bool is_static_cast(const Primitive& p) {
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
std::string get_type_string(Dtype d);
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
if constexpr (std::is_same_v<T, double>) {
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
} else {
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
}
os << x.item<T>() << std::setprecision(old_precision);
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< x.item<T>() << std::setprecision(old_precision);
}
template <typename T>
@@ -57,21 +60,8 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous,
const std::function<allocator::Buffer(size_t)>& mallocfn =
allocator::malloc);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant);
// Return whether the kernel should use large index.
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous);
} // namespace mlx::core

View File

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/backend/common/utils.h"
#include "mlx/array.h"
namespace mlx::core {
@@ -22,27 +22,23 @@ enum class CopyType {
GeneralGeneral
};
inline bool set_copy_output_data(
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (is_donatable(in, out)) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
return true;
} else {
out.set_data(
mallocfn(in.data_size() * out.itemsize()),
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(mallocfn(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
return false;
}
}

View File

@@ -1,67 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include <sstream>
namespace mlx::core {
inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto a_batch_strides = batch_strides[0];
auto b_batch_strides = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
a_batch_strides.push_back(0);
b_batch_strides.push_back(0);
}
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
}
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
} // namespace mlx::core

View File

@@ -5,9 +5,11 @@
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
@@ -17,15 +19,6 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
return shapes_without_reduction_axes(
std::move(shape), std::move(strides), axes);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@@ -51,9 +51,5 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const std::vector<int>& axes);
} // namespace mlx::core

View File

@@ -14,13 +14,17 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
}
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const Strides& out_strides,
int64_t data_offset,
size_t data_offset,
size_t data_size,
array& out) {
// Compute row/col contiguity
@@ -41,30 +45,23 @@ void slice(
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
out.set_data(allocator::malloc(0));
out.set_data(nullptr);
return;
}
// Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
// Get the location of the end based on the inp strides and out.shape()
int64_t low_idx = 0;
int64_t high_idx = 0;
for (int i = 0; i < inp_strides.size(); ++i) {
auto delta = inp_strides[i] * (out.shape()[i] - 1);
if (inp_strides[i] > 0) {
high_idx += delta;
} else {
low_idx += delta;
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
int64_t data_size = (high_idx - low_idx) + 1;
if (data_size < 0) {
std::ostringstream msg;
msg << "[slice] Computed invalid data size: " << data_size << ".";
throw std::runtime_error(msg.str());
if (data_end < 0) {
data_end += in.data_size();
}
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}

View File

@@ -11,8 +11,6 @@ namespace mlx::core {
enum class TernaryOpType {
ScalarScalarScalar,
VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General,
};
@@ -27,14 +25,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else {
topt = TernaryOpType::General;
}
@@ -46,8 +36,7 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
TernaryOpType topt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
TernaryOpType topt) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
out.copy_shared_buffer(x);
@@ -58,25 +47,24 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
out.set_data(
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
mallocfn(out.itemsize() * b.data_size()),
allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
}
break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General:
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(mallocfn(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@@ -0,0 +1,57 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void transpose(const array& in, array& out, const std::vector<int>& axes) {
Strides out_strides(out.ndim());
for (int ax = 0; ax < axes.size(); ++ax) {
out_strides[ax] = in.strides()[axes[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
auto [_, rc, cc] = check_contiguity(out.shape(), out_strides);
flags.row_contiguous = rc;
flags.col_contiguous = cc;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void as_transposed(array& out, const std::vector<int>& axes) {
assert(out.data_size() == out.size() && out.flags().contiguous);
// Calculate the contiguous strides.
Strides strides(out.ndim(), 1);
for (int i = out.ndim() - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * out.shape(i);
}
// Calculate the new strides for transposing.
Strides new_strides;
new_strides.reserve(out.ndim());
for (auto ax : axes) {
new_strides.push_back(strides[ax]);
}
auto [ds, rc, cc] = check_contiguity(out.shape(), new_strides);
auto flags = out.flags();
flags.row_contiguous = rc;
flags.col_contiguous = cc;
out.copy_shared_buffer(out, new_strides, flags, ds);
}
} // namespace mlx::core

View File

@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void transpose(const array& in, array& out, const std::vector<int>& axes);
void as_transposed(array& out, const std::vector<int>& axes);
} // namespace mlx::core

View File

@@ -1,29 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
inline void set_unary_output_data(
const array& in,
array& out,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(mallocfn(out.nbytes()));
}
}
} // namespace mlx::core

View File

@@ -1,22 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include "mlx/backend/common/utils.h"
namespace mlx::core {
std::filesystem::path current_binary_dir() {
static std::filesystem::path binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path();
}();
return binary_dir;
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
@@ -114,118 +101,4 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
}
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == pow2) {
break;
}
}
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
}
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor) {
// Compute the 2d grid dimensions such that the total size of the grid is
// divided by divisor.
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
// No need to add this shape we can just remove it from the divisor.
if (divisor % shape[i] == 0) {
divisor /= shape[i];
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
if (divisor > 1) {
if (grid_x % divisor == 0) {
grid_x /= divisor;
divisor = 1;
} else if (grid_y % divisor == 0) {
grid_y /= divisor;
divisor = 1;
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
if (divisor > 1) {
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
auto gx = (dim0 + bx - 1) / bx;
auto gy = (dim1 + by - 1) / by;
auto gz = (dim2 + bz - 1) / bz;
return std::make_pair(
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
} // namespace mlx::core

View File

@@ -2,17 +2,12 @@
#pragma once
#include <filesystem>
#include <tuple>
#include <vector>
#include "mlx/array.h"
namespace mlx::core {
// Return the directory that contains current shared library.
std::filesystem::path current_binary_dir();
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0;
@@ -75,31 +70,6 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a,
int64_t size_cap = std::numeric_limits<int32_t>::max());
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 2^pow2
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
// possibly broadcasted array
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
// Same as above but we do an implicit division with divisor.
// Basically, equivalent to factorizing
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor);
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();
@@ -162,6 +132,11 @@ struct ContiguousIterator {
};
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
// Conditions for {row/col}_contiguous
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
size_t no_broadcast_data_size = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
@@ -195,11 +170,4 @@ void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}
} // namespace mlx::core

View File

@@ -46,7 +46,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

View File

@@ -14,8 +14,10 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = remove_index(in.strides(), axis);
Shape shape = remove_index(in.shape(), axis);
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();

View File

@@ -14,11 +14,233 @@
namespace mlx::core {
namespace {
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Add(), stream());
binary(a, b, out, detail::Add(), stream());
}
void DivMod::eval_cpu(
@@ -102,14 +324,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Divide(), stream());
binary(a, b, out, detail::Divide(), stream());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Remainder(), stream());
binary(a, b, out, detail::Remainder(), stream());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -150,90 +372,89 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
}
});
} else {
comparison_op_cpu(a, b, out, detail::Equal(), stream());
comparison_op(a, b, out, detail::Equal(), stream());
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op_cpu(
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
binary_float(a, b, out, detail::LogAddExp(), stream());
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
binary(in1, in2, out, detail::LogicalAnd(), stream());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
binary(in1, in2, out, detail::LogicalOr(), stream());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Maximum(), stream());
binary(a, b, out, detail::Maximum(), stream());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Minimum(), stream());
binary(a, b, out, detail::Minimum(), stream());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Multiply(), stream());
binary(a, b, out, detail::Multiply(), stream());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Power(), stream());
binary(a, b, out, detail::Power(), stream());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Subtract(), stream());
binary(a, b, out, detail::Subtract(), stream());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -242,19 +463,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
switch (op_) {
case BitwiseBinary::And:
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
binary_int(a, b, out, detail::BitwiseAnd(), stream());
break;
case BitwiseBinary::Or:
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
binary_int(a, b, out, detail::BitwiseOr(), stream());
break;
case BitwiseBinary::Xor:
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
binary_int(a, b, out, detail::BitwiseXor(), stream());
break;
case BitwiseBinary::LeftShift:
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
binary_int(a, b, out, detail::LeftShift(), stream());
break;
case BitwiseBinary::RightShift:
binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
binary_int(a, b, out, detail::RightShift(), stream());
break;
}
}
@@ -263,7 +484,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
binary_float(a, b, out, detail::ArcTan2(), stream());
}
} // namespace mlx::core
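
The binary, comparison_op, binary_float and binary_int helpers above all follow the same pattern: record the arrays with the command encoder, then switch once on a runtime dtype and hand off to the typed binary_op<T, Op> kernel. A reduced, standalone sketch of that dispatch pattern is below; Dtype, binary_kernel and binary_dispatch are illustrative stand-ins, not the MLX API.

// Minimal sketch of the dtype-switch dispatch used by the helpers above.
#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class Dtype { int32, float32, float64 };

// The typed kernel: in MLX this role is played by binary_op<T, Op>.
template <typename T, typename Op>
void binary_kernel(const T* a, const T* b, T* out, int n, Op op) {
  for (int i = 0; i < n; ++i) {
    out[i] = op(a[i], b[i]);
  }
}

// The runtime-dtype front end: switch once, then run the typed loop.
template <typename Op>
void binary_dispatch(Dtype dt, const void* a, const void* b, void* out, int n,
                     Op op) {
  switch (dt) {
    case Dtype::int32:
      binary_kernel(static_cast<const int32_t*>(a),
                    static_cast<const int32_t*>(b),
                    static_cast<int32_t*>(out), n, op);
      break;
    case Dtype::float32:
      binary_kernel(static_cast<const float*>(a), static_cast<const float*>(b),
                    static_cast<float*>(out), n, op);
      break;
    case Dtype::float64:
      binary_kernel(static_cast<const double*>(a),
                    static_cast<const double*>(b),
                    static_cast<double*>(out), n, op);
      break;
    default:
      throw std::runtime_error("unsupported dtype");
  }
}

int main() {
  float a[3] = {1.f, 2.f, 3.f}, b[3] = {10.f, 20.f, 30.f}, out[3];
  binary_dispatch(Dtype::float32, a, b, out, 3,
                  [](auto x, auto y) { return x + y; });
  printf("%g %g %g\n", out[0], out[1], out[2]); // 11 22 33
  return 0;
}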

View File

@@ -7,7 +7,6 @@
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
@@ -291,227 +290,4 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt);
}
template <typename Op>
void binary_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace mlx::core

View File

@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// The decomposition is computed in place, so just copy the input to the
// output.
copy_cpu(
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -15,7 +15,6 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core {
@@ -95,11 +94,7 @@ void* compile(
kernel_file_name = kernel_name;
}
auto output_dir =
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
auto output_dir = std::filesystem::temp_directory_path();
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();
@@ -151,9 +146,18 @@ inline void build_kernel(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::function<bool(size_t)>& is_constant,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
#ifdef _MSC_VER
@@ -162,28 +166,25 @@ inline void build_kernel(
#endif
// Start the kernel
os << "void " << kernel_name
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(i)) {
if (is_constant(x)) {
continue;
}
const auto& x = inputs[i];
auto& xname = namer.get_name(x);
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const int64_t* " << xname << "_strides = strides["
<< strides_index++ << "];" << std::endl;
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
}
}
@@ -193,8 +194,10 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output size
if (contiguous) {
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
@@ -208,11 +211,10 @@ inline void build_kernel(
}
// Read the inputs in tmps
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(i)) {
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
@@ -236,7 +238,7 @@ inline void build_kernel(
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
os << x.primitive().name();
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
@@ -262,9 +264,8 @@ inline void build_kernel(
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
if (is_constant(i) || is_scalar(x)) {
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
@@ -286,33 +287,65 @@ inline void build_kernel(
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
auto& encoder = cpu::get_command_encoder(stream());
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Collect function input arguments.
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
const auto& x = inputs[i];
auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
}
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(ndim);
kernel_name += std::to_string(shape.size());
}
// Get the function
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
auto fn_ptr = compile(kernel_name, [&]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
@@ -322,7 +355,7 @@ void Compiled::eval_cpu(
inputs_,
outputs_,
tape_,
is_constant_,
constant_ids_,
contiguous,
ndim);
// Close extern "C"
@@ -330,26 +363,26 @@ void Compiled::eval_cpu(
return kernel.str();
});
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous);
for (auto& x : outputs) {
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
if (contiguous) {
Shape out_shape;
if (!contiguous) {
out_shape = outputs[0].shape();
args.push_back((void*)out_shape.data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
encoder.dispatch([fun,
args = std::move(args),
strides = std::move(strides),
shape = std::move(shape)]() mutable {
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
auto fun = (void (*)(void**))fn_ptr;
encoder.dispatch(
[fun,
args = std::move(args),
strides = std::move(strides),
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
}
} // namespace mlx::core
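
The per-input strides computed in Compiled::eval_cpu above implement the usual broadcasting trick: any dimension the input lacks, or has with size one, gets stride zero so the generated kernel can index every input with the output's shape. A simplified standalone sketch follows; it ignores the size-1 output-dimension special case handled above, and broadcast_strides_sketch is a hypothetical name.

// Minimal sketch: strides that broadcast an input to an output shape.
// Assumes the input's rank is no larger than the output's.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> broadcast_strides_sketch(
    const std::vector<int>& out_shape,
    const std::vector<int>& in_shape,
    const std::vector<int64_t>& in_strides) {
  std::vector<int64_t> strides;
  size_t offset = out_shape.size() - in_shape.size();
  for (size_t j = 0; j < out_shape.size(); ++j) {
    if (j < offset || in_shape[j - offset] == 1) {
      strides.push_back(0); // dimension is broadcast: never advance
    } else {
      strides.push_back(in_strides[j - offset]);
    }
  }
  return strides;
}

int main() {
  // A (3,) row-contiguous vector broadcast against a (2, 4, 3) output.
  auto s = broadcast_strides_sketch({2, 4, 3}, {3}, {1});
  for (auto v : s) printf("%lld ", (long long)v); // 0 0 1
  printf("\n");
  return 0;
}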

File diff suppressed because it is too large.

View File

@@ -295,11 +295,7 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream) {
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
@@ -309,7 +305,7 @@ void copy_cpu_inplace(
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
}
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to
@@ -319,10 +315,10 @@ void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_cpu_inplace(src, dst, ctype, stream);
copy_inplace(src, dst, ctype, stream);
}
void copy_cpu_inplace(
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
@@ -377,10 +373,4 @@ void copy_cpu_inplace(
});
}
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -10,14 +10,10 @@
namespace mlx::core {
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream);
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace(
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
@@ -30,7 +26,4 @@ void copy_cpu_inplace(
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core

View File

@@ -13,7 +13,9 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) {
return {arr, false};
} else {
return {contiguous_copy_cpu(arr, stream), true};
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
}
};
@@ -32,7 +34,8 @@ void AllReduce::eval_cpu(
}
return in;
} else {
array arr_copy = contiguous_copy_cpu(in, s);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}
@@ -95,9 +98,4 @@ void Recv::eval_cpu(
distributed::detail::recv(group(), outputs[0], src_, stream());
}
void ReduceScatter::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed

View File

@@ -1,281 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
complex64_t to_complex(T r, T i) {
return {static_cast<float>(r), static_cast<float>(i)};
}
template <typename T, class Enable = void>
struct EigWork {};
template <typename T>
struct EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using O = complex64_t;
char jobl;
char jobr;
int N;
int lwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
T work;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
if (compute_eigenvectors) {
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
}
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, O* values, O* vectors) {
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
T* vec_tmp = nullptr;
if (vectors) {
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
}
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
eig_tmp,
eig_tmp + N,
vectors ? vec_tmp : nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
work,
&lwork,
&info);
for (int i = 0; i < N; ++i) {
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
}
if (vectors) {
for (int i = 0; i < N; ++i) {
if (values[i].imag() != 0) {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] =
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
vectors[(i + 1) * N + j] =
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
}
}
}
}
}
};
template <>
struct EigWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
using O = T;
char jobl;
char jobr;
int N;
int lwork;
int lrwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
T work;
R rwork;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&rwork,
&info);
lwork = static_cast<int>(work.real());
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
}
void run(T* a, T* values, T* vectors) {
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
values,
vectors,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&info);
}
};
template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
auto a_ptr = a.data<T>();
auto val_ptr = values.data<complex64_t>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
complex64_t* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<complex64_t>();
}
encoder.dispatch([a_ptr,
val_ptr,
vec_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(a_ptr, val_ptr, vec_ptr);
a_ptr += N * N;
val_ptr += N;
if (vec_ptr) {
vec_ptr += N * N;
}
if (work.info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< work.info;
throw std::runtime_error(msg.str());
}
}
});
encoder.add_temporary(a);
}
} // namespace
void Eig::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy_cpu(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
values.set_data(allocator::malloc(values.nbytes()));
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.set_data(
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
}
switch (a.dtype()) {
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case float64:
eig_impl<double>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case complex64:
eig_impl<std::complex<float>>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error(
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
}
}
} // namespace mlx::core

View File

@@ -12,133 +12,6 @@ namespace mlx::core {
namespace {
template <typename T, class Enable = void>
struct EighWork {};
template <typename T>
struct EighWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
char jobz;
char uplo;
int N;
int lwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, T* values) {
syevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<int*>(buffers[1].buffer.raw_ptr()),
&liwork,
&info);
}
};
template <>
struct EighWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
char jobz;
char uplo;
int N;
int lwork;
int lrwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
T work;
R rwork;
int iwork;
heevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&rwork,
&lrwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work.real());
lrwork = static_cast<int>(rwork);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, R* values) {
heevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&lrwork,
static_cast<int*>(buffers[2].buffer.raw_ptr()),
&liwork,
&info);
if (jobz == 'V') {
// We have pre-transposed the vectors but we also must conjugate them
// when they are complex.
//
// We could vectorize this but it is so fast in comparison to heevd that
// it doesn't really matter.
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
*vectors = std::conj(*vectors);
vectors++;
}
}
}
}
};
template <typename T>
void eigh_impl(
array& vectors,
@@ -146,10 +19,8 @@ void eigh_impl(
const std::string& uplo,
bool compute_eigenvectors,
Stream stream) {
using R = typename EighWork<T>::R;
auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<R>();
auto eig_ptr = values.data<T>();
char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream);
@@ -162,17 +33,49 @@ void eigh_impl(
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
EighWork<T> work(jobz, uplo, N);
int lwork = -1;
int liwork = -1;
int info;
{
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
}
// Work loop
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(vec_ptr, eig_ptr);
syevd<T>(
&jobz,
&uplo,
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
vec_ptr += N * N;
eig_ptr += N;
if (work.info != 0) {
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< work.info;
<< info;
throw std::runtime_error(msg.str());
}
}
@@ -196,7 +99,7 @@ void Eigh::eval_cpu(
values.set_data(allocator::malloc(values.nbytes()));
copy_cpu(
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -228,10 +131,6 @@ void Eigh::eval_cpu(
eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
case complex64:
eigh_impl<std::complex<float>>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64.");

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
@@ -48,15 +49,9 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
}
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,

View File

@@ -88,47 +88,4 @@ void matmul<double>(
}
}
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core

View File

@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
copy_cpu(
copy(
in,
out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -257,11 +257,15 @@ void gather_axis(
const array& ind,
array& out,
const int axis) {
auto shape = remove_index(ind.shape(), axis);
ContiguousIterator ind_it(
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
ContiguousIterator src_it(
shape, remove_index(src.strides(), axis), src.ndim() - 1);
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
@@ -517,7 +521,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream());
copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
@@ -581,11 +585,15 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto shape = remove_index(idx.shape(), axis);
ContiguousIterator idx_it(
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
ContiguousIterator upd_it(
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();
@@ -686,7 +694,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream());
copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);
@@ -747,108 +755,4 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
});
}
template <typename T>
void masked_scatter_impl(const array& mask, const array& src, array& out) {
ContiguousIterator mask_it(mask);
ContiguousIterator src_it(src);
ContiguousIterator out_it(out);
const bool* mask_ptr = mask.data<bool>();
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
const size_t batch_count = mask.shape(0);
const size_t mask_batch_size = mask.size() / batch_count;
const size_t src_batch_size = src.size() / batch_count;
for (uint b = 0; b < batch_count; ++b) {
size_t src_consumed = 0;
src_it.seek(b * src_batch_size);
for (size_t i = 0; i < mask_batch_size; ++i) {
if (mask_ptr[mask_it.loc]) {
if (src_consumed >= src_batch_size) {
throw std::runtime_error(
"[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
}
dst_ptr[out_it.loc] = src_ptr[src_it.loc];
src_it.step();
++src_consumed;
}
mask_it.step();
out_it.step();
}
}
}
void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& dst = inputs[0];
auto& mask = inputs[1];
auto& src = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(dst, out, ctype, stream());
if (mask.size() == 0) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(mask);
encoder.set_input_array(src);
encoder.set_output_array(out);
encoder.dispatch([mask = array::unsafe_weak_copy(mask),
src = array::unsafe_weak_copy(src),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
masked_scatter_impl<bool>(mask, src, out);
break;
case uint8:
masked_scatter_impl<uint8_t>(mask, src, out);
break;
case uint16:
masked_scatter_impl<uint16_t>(mask, src, out);
break;
case uint32:
masked_scatter_impl<uint32_t>(mask, src, out);
break;
case uint64:
masked_scatter_impl<uint64_t>(mask, src, out);
break;
case int8:
masked_scatter_impl<int8_t>(mask, src, out);
break;
case int16:
masked_scatter_impl<int16_t>(mask, src, out);
break;
case int32:
masked_scatter_impl<int32_t>(mask, src, out);
break;
case int64:
masked_scatter_impl<int64_t>(mask, src, out);
break;
case float16:
masked_scatter_impl<float16_t>(mask, src, out);
break;
case float32:
masked_scatter_impl<float>(mask, src, out);
break;
case float64:
masked_scatter_impl<double>(mask, src, out);
break;
case bfloat16:
masked_scatter_impl<bfloat16_t>(mask, src, out);
break;
case complex64:
masked_scatter_impl<complex64_t>(mask, src, out);
break;
}
});
}
} // namespace mlx::core
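
masked_scatter_impl above walks the mask in order and, wherever it is true, consumes the next source element, throwing if the source runs out. The same semantics on flat vectors, as a small self-contained sketch; masked_scatter_sketch is a hypothetical name, and the real implementation additionally handles batches and strided iteration.

// Minimal sketch of masked-scatter semantics on flat vectors.
#include <cstdio>
#include <stdexcept>
#include <vector>

template <typename T>
void masked_scatter_sketch(const std::vector<bool>& mask,
                           const std::vector<T>& src, std::vector<T>& out) {
  size_t consumed = 0;
  for (size_t i = 0; i < mask.size(); ++i) {
    if (mask[i]) {
      if (consumed >= src.size()) {
        throw std::runtime_error("source does not have enough elements");
      }
      out[i] = src[consumed++]; // take the next unused source element
    }
  }
}

int main() {
  std::vector<float> out = {0, 0, 0, 0, 0};
  std::vector<bool> mask = {true, false, true, false, true};
  std::vector<float> src = {7, 8, 9};
  masked_scatter_sketch(mask, src, out);
  for (auto v : out) printf("%g ", v); // 7 0 8 0 9
  printf("\n");
  return 0;
}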

View File

@@ -115,7 +115,7 @@ void inverse_impl(
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy_cpu(
copy(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -2,7 +2,6 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include <algorithm>
#include <sstream>
#include <vector>

Some files were not shown because too many files have changed in this diff.