Compare commits

..

8 Commits

Author                 SHA1        Message                                 Date
Angelos Katharopoulos  a22d0bf273  Add stricter condition to matrix sdpa   2025-08-06 19:51:14 -07:00
Jagrit Digani          99d8de8445  Fix cudnn routing                       2025-08-06 15:05:58 -07:00
Jagrit Digani          c66b76a8c8  Update routing                          2025-08-06 15:01:15 -07:00
Jagrit Digani          f81edd184f  Complete 2 pass sdpav                   2025-08-06 13:57:40 -07:00
Jagrit Digani          7f8ba2a003  [WIP] 2 pass sdpav                      2025-08-06 09:56:39 -07:00
Jagrit Digani          c28249b81a  Add more nvtx range for debug           2025-08-06 09:56:39 -07:00
Jagrit Digani          e74bcdc5e3  Add sdpa file                           2025-08-06 09:56:39 -07:00
Jagrit Digani          d8ed6c1aa3  Add base cudnn attention support        2025-08-06 09:56:39 -07:00
425 changed files with 7280 additions and 29886 deletions
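The eight commits above add a cuDNN-backed attention path and routing logic for scaled dot-product attention (SDPA) on the CUDA backend. As an editorial illustration only (not code from this diff), the snippet below shows the user-facing MLX op that such backend kernels sit behind; shapes and values are arbitrary, and which kernel handles the call is an internal dispatch detail.

# Illustration only: the public MLX op behind which the cuDNN / 2-pass SDPA
# kernels from these commits would be selected (not taken from the diff).
import mlx.core as mx

B, H, L, D = 1, 8, 128, 64  # batch, heads, sequence length, head dim (arbitrary)
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

# Fused scaled dot-product attention; the backend chooses a kernel per shape/device.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)
mx.eval(out)
print(out.shape)  # (1, 8, 128, 64)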

.circleci/config.yml (new file, 724 lines)

@@ -0,0 +1,724 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
- run:
name: Install dependencies
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
name: Install Python package
command: |
uv venv --python 3.9
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build example extension
command: |
source .venv/bin/activate
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source .venv/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e .
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
python_version:
type: string
default: "3.9"
xcode_version:
type: string
default: "16.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
command: |
source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
build_linux_release:
parameters:
python_version:
type: string
default: "3.9"
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Build wheel
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]


@@ -1,15 +0,0 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh


@@ -1,38 +0,0 @@
name: 'Build Documentation'
description: 'Build documentation'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-linux
- name: Install dependencies
shell: bash
run: |
sudo apt-get install -y doxygen
source .venv/bin/activate
pip install -r docs/requirements.txt
pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: bash
run: tar -cf artifact.tar -C docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error


@@ -1,40 +0,0 @@
name: 'Build Linux wheel'
description: 'Build Linux wheel'
inputs:
build-backend:
description: 'Build the backend mlx-cpu package'
type: boolean
required: false
default: false
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Generate package stubs
shell: bash
run: |
pip install -e ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
- name: Build Python package
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
bash python/scripts/repair_linux.sh ${{ inputs.arch }}
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}


@@ -1,41 +0,0 @@
name: 'Build and Test on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
runs:
using: "composite"
steps:
- name: Install Python package
id: python_build
shell: sh
env:
DEBUG: 1
CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
# Can not build tests when the built executables can not run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
fi
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build CPP only
shell: bash
run: |
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake --build build -j $(nproc)


@@ -1,34 +0,0 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
build-backend:
description: 'Build the backend mlx-metal package'
type: boolean
required: false
default: false
runs:
using: "composite"
steps:
- name: Build Python package
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
pip install build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w


@@ -1,88 +0,0 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
runs:
using: "composite"
steps:
- name: Install dependencies
env:
DEBUG: 1
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Install tests dependencies
shell: bash -l {0}
run: |
pip install numpy torch tensorflow unittest-xml-reporting
- name: Run Python tests
shell: bash -l {0}
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
shell: bash -l {0}
run: |
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext --inplace
python test.py
- name: Build CPP only
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
shell: bash -l {0}
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
shell: bash -l {0}
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit


@@ -1,86 +0,0 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
toolkit:
description: 'Which toolkit to install'
required: false
default: 'cpu'
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Use ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
max-size: 1GB
- name: Install common dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
- name: Setup Python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
pip install setuptools cmake nanobind==2.4.0
echo PATH=$PATH >> $GITHUB_ENV
# Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
- name: Install MPI
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Install CUDA toolkit
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
env:
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: |
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
- name: CUDA packages and driver report
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"


@@ -1,24 +0,0 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- uses: conda-incubator/setup-miniconda@v3
with:
miniconda-version: "latest"
python-version: ${{ inputs.python-version }}


@@ -1,69 +0,0 @@
name: 'Run Linux tests'
inputs:
has-gpu:
description: 'Run GPU tests'
required: false
default: false
runs:
using: "composite"
steps:
- name: Run MPI tests
shell: bash
run: |
echo "::group::MPI tests"
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
run: |
echo "::group::Distributed tests"
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::Python tests - CPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::Python tests - GPU"
python -m tests discover python/tests -v
echo "::endgroup::"
- name: Run CPP tests - CPU
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::CPP tests - CPU"
./build/tests/tests
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::CPP tests - GPU"
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
echo "::endgroup::"


@@ -1,6 +0,0 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"


@@ -1,27 +0,0 @@
#!/bin/bash
set -ex
# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
blas-devel \
lapack-devel \
openblas-devel \
make \
cmake \
clang \
git
dnf clean all
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++
mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd


@@ -1,108 +0,0 @@
name: Build and Test
on:
pull_request:
push:
branches:
- main
# For testing CI without starting a pull request:
- test/*
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
cuda_build_and_test:
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
mac_build_and_test:
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
macos-target: ["14.0", "15.0"]
runs-on: [self-hosted, macos]
env:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh


@@ -1,28 +0,0 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4


@@ -1,96 +0,0 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12", "3.13", "3.14"]
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7

.github/workflows/pull_request.yml (new file, 20 lines)

@@ -0,0 +1,20 @@
on:
pull_request:
branches:
- main
jobs:
check_lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files


@@ -1,238 +0,0 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
inputs:
dev_release:
description: "Do a dev release or regular release"
required: true
default: "false"
permissions:
contents: read
jobs:
setup:
runs-on: ubuntu-latest
steps:
- name: Set publishing variables
run: echo "Publishing setup complete"
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_cuda_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
pattern: mlx-cpu-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/


@@ -1,10 +1,4 @@
 repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v6.0.0
-    hooks:
-      - id: check-yaml
-      # - id: end-of-file-fixer
-      # - id: trailing-whitespace
   - repo: https://github.com/pre-commit/mirrors-clang-format
     rev: v19.1.7
     hooks:


@@ -19,17 +19,12 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>

-# Organizations
-
-MLX has received contributions from the following companies:
-
-- NVIDIA Corporation & Affiliates
 # Third-Party Software

 MLX leverages several third-party software, listed here together with


@@ -26,7 +26,6 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(CMAKE_INSTALL_MESSAGE NEVER)
-set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

 # ----------------------------- Configuration -----------------------------
 option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -74,7 +73,6 @@ endif()
 if(MLX_USE_CCACHE)
   find_program(CCACHE_PROGRAM ccache)
   if(CCACHE_PROGRAM)
-    message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
     set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
     set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
     set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
@@ -89,26 +87,22 @@ cmake_policy(SET CMP0135 NEW)
 add_library(mlx)

-# Supress warnings: note: parameter passing for argument of type
-# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
-# 10.1
-target_compile_options(mlx PRIVATE -Wno-psabi)
+if(MLX_BUILD_METAL)
+  set(METAL_LIB "-framework Metal")
+  set(FOUNDATION_LIB "-framework Foundation")
+  set(QUARTZ_LIB "-framework QuartzCore")
+endif()

 if(MLX_BUILD_CUDA)
   enable_language(CUDA)
 endif()

-if(MLX_BUILD_METAL)
-  find_library(METAL_LIB Metal)
-  find_library(FOUNDATION_LIB Foundation)
-  find_library(QUARTZ_LIB QuartzCore)
-  if(METAL_LIB)
-    message(STATUS "Metal found ${METAL_LIB}")
-  else()
-    message(
-      FATAL_ERROR
-      "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
-  endif()
+if(MLX_BUILD_METAL AND NOT METAL_LIB)
+  message(STATUS "Metal not found. Unable to build GPU")
+  set(MLX_BUILD_METAL OFF)
+  set(MLX_METAL_DEBUG OFF)
+elseif(MLX_BUILD_METAL)
+  message(STATUS "Building METAL sources")

   if(MLX_METAL_DEBUG)
     add_compile_definitions(MLX_METAL_DEBUG)
@@ -117,12 +111,7 @@ if(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_SDK_VERSION
-    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
-  execute_process(
-    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
-    OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
-    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)

   if(${MACOS_SDK_VERSION} LESS 14.0)
     message(
@@ -132,12 +121,9 @@ if(MLX_BUILD_METAL)
   message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")

   set(METAL_CPP_URL
-      https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
+      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)

   if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
-    if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
-      message(FATAL_ERROR "MLX requires macOS >= 14.0")
-    endif()
     set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
   endif()

   execute_process(
@@ -146,6 +132,7 @@ if(MLX_BUILD_METAL)
     "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
     OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
   FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
   FetchContent_MakeAvailable(metal_cpp)
   target_include_directories(
     mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
@@ -153,12 +140,6 @@ if(MLX_BUILD_METAL)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()

-if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
-  # With newer clang/gcc versions following libs are implicitly linked, but when
-  # building on old distributions they need to be explicitly listed.
-  target_link_libraries(mlx PRIVATE dl pthread)
-endif()
 if(WIN32)
   if(MSVC)
     # GGUF does not build with MSVC.
@@ -186,7 +167,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
   else()
-    message(STATUS "Accelerate not found, using default backend.")
+    message(STATUS "Accelerate or arm neon not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
   endif()


@@ -2,7 +2,7 @@
 [**Quickstart**](#quickstart) | [**Installation**](#installation) |
 [**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
 [**Examples**](#examples)

 [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
 Some key features of MLX include:

 - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
   also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
   [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
   the Python API. MLX has higher-level packages like `mlx.nn` and
   `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
   more complex models.

 - **Composable function transformations**: MLX supports composable function
   transformations for automatic differentiation, automatic vectorization,
   and computation graph optimization.

 - **Lazy computation**: Computations in MLX are lazy. Arrays are only
   materialized when needed.

 - **Dynamic graph construction**: Computation graphs in MLX are constructed
   dynamically. Changing the shapes of function arguments does not trigger
   slow compilations, and debugging is simple and intuitive.

 - **Multi-device**: Operations can run on any of the supported devices
   (currently the CPU and the GPU).

 - **Unified memory**: A notable difference from MLX and other frameworks
   is the *unified memory model*. Arrays in MLX live in shared memory.
   Operations on MLX arrays can be performed on any of the supported
   device types without transferring data.

 MLX is designed by machine learning researchers for machine learning
 researchers. The framework is intended to be user-friendly, but still efficient
 to train and deploy models. The design of the framework itself is also
 conceptually simple. We intend to make it easy for researchers to extend and
 improve MLX with the goal of quickly exploring new ideas.

 The design of MLX is inspired by frameworks like
 [NumPy](https://numpy.org/doc/stable/index.html),
@@ -91,7 +91,7 @@ Checkout the
 [documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
 for more information on building the C++ and Python APIs from source.

 ## Contributing

 Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
 on contributing to MLX. See the
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
 MLX useful in your research and wish to cite it, please use the following
 BibTex entry:

-```text
+```
 @software{mlx2023,
   author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
   title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
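The README hunk above lists lazy computation and composable function transformations among MLX's key features. Purely as an editorial illustration (not part of the README diff), those two points look roughly like this in mlx.core:

# Editorial illustration of the "lazy computation" and "composable function
# transformations" bullets above; not part of the README diff.
import mlx.core as mx

def loss(w, x):
    return mx.sum(mx.square(w * x))

w = mx.array([1.0, 2.0, 3.0])
x = mx.array([0.5, -1.0, 2.0])

value = loss(w, x)       # lazy: builds a graph, nothing is computed yet
grad_fn = mx.grad(loss)  # composable transformation: gradient w.r.t. w
dw = grad_fn(w, x)

mx.eval(value, dw)       # arrays materialize only when evaluated (or printed)
print(value, dw)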


@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
 void time_irregular_binary_ops_4D() {
   auto device = mx::default_device();
-  mx::Shape shape = {8, 8, 512, 512};
+  std::vector<int> shape = {8, 8, 512, 512};
   auto a = mx::random::uniform(shape);
   auto b = mx::random::uniform(shape);
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
 void time_irregular_reshape() {
   auto device = mx::default_device();
-  mx::Shape shape;
+  std::vector<int> shape;
   auto reshape_fn = [&shape, device](const mx::array& a) {
     return mx::reshape(a, shape, device);
   };
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
 void time_irregular_astype_2D() {
   auto device = mx::default_device();
   int size = 2048;
-  mx::Shape shape = {size, size};
+  std::vector<int> shape = {size, size};
   auto a = mx::random::uniform(shape);
   TIMEM("2D regular", mx::astype, a, mx::int32, device);


@@ -142,7 +142,9 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
     t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
     c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
-    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
+    c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
+        np.float32
+    )

     atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -161,7 +163,7 @@ def get_gflop_count(B, M, N, K):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run gemm benchmarks")
-    dtypes = ("float32", "float16", "complex64")
+    dtypes = ("float32", "float16")
     transposes = ("nn", "nt", "tn")
     shapes = (
         (16, 234, 768, 3072),
@@ -185,7 +187,7 @@ if __name__ == "__main__":
         diff = gflops_mx / gflops_pt - 1.0
         print(
-            f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
+            f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
         )
         if gflops_pt >= 2.0 * gflops_mx:
             print("ATTENTION ^^^^^^^")


@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.

+import argparse
 import os
 import subprocess
 import time
@@ -195,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
 for transpose in (False, True):
-    for dtype in ("float32", "float16", "complex64"):
+    for dtype in ("float32", "float16"):
         fig, axs = plt.subplots(
             len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
         )
@@ -214,7 +215,7 @@ for transpose in (False, True):
         fig.suptitle(f"{device_name}: {dtype} {op_name}")
         fig.savefig(
             os.path.join(
-                results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
+                results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
             )
         )
         plt.close(fig)

View File

@@ -1,212 +0,0 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter
RESULTS_DIR = "./results"
if not os.path.isdir(RESULTS_DIR):
os.mkdir(RESULTS_DIR)
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
TORCH_DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else ("cuda" if torch.cuda.is_available() else "cpu")
)
N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")
def _power_of_two_formatter(value, _position):
if value <= 0:
return ""
exponent = int(round(math.log2(value)))
if abs(value - (1 << exponent)) / value > 1e-6:
return f"{value:g}"
return f"$2^{{{exponent}}}$"
def torch_sync():
if TORCH_DEVICE.type == "cuda":
torch.cuda.synchronize()
elif TORCH_DEVICE.type == "mps":
torch.mps.synchronize()
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
outs = []
for _ in range(N_ITER_FUNC):
out = copy(self_arr)
out[mask_arr] = src_arr
outs.append(out)
mx.eval(outs)
return outs
@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
outs = []
for _ in range(N_ITER_FUNC):
out = self_tensor.clone()
out.masked_scatter_(mask_tensor, src_tensor)
outs.append(out)
torch_sync()
return outs
def measure(fn):
for _ in range(N_WARMUP):
fn()
start = time.perf_counter_ns()
for _ in range(N_ITER_BENCH):
fn()
end = time.perf_counter_ns()
return (end - start) * 1e-9
def bytes_touched(length, true_count, item_size):
mask_bytes = length
self_bytes = length * item_size * 2 # read + write
src_bytes = true_count * item_size
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
def build_case(length, density, np_dtype, torch_dtype):
true_count = max(1, int(round(length * density)))
rng = np.random.default_rng()
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
mask_np = np.zeros(length, dtype=bool)
mask_np[:true_count] = True
rng.shuffle(mask_np)
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
self_mlx = mx.array(self_np)
mask_mlx = mx.array(mask_np)
src_mlx = mx.array(src_np)
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
# Correctness check once per configuration
mx_out = mx.array(self_np)
mx_out[mask_mlx] = src_mlx
mx.eval(mx_out)
torch_out = self_torch.clone()
torch_out.masked_scatter_(mask_torch, src_torch)
atol = 5e-3 if np_dtype == np.float16 else 1e-5
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
raise AssertionError("masked_scatter results diverged between MLX and Torch")
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
def bench_case(length, density, dtype):
np_dtype = getattr(np, dtype)
torch_dtype = getattr(torch, dtype)
(
self_mlx,
mask_mlx,
src_mlx,
self_torch,
mask_torch,
src_torch,
true_count,
) = build_case(length, density, np_dtype, torch_dtype)
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
time_torch = measure(
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
)
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
bytes_per_gb = float(1024**3)
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
return time_mlx, time_torch, mlx_gbps, torch_gbps
def plot_density(ax_perf, ax_speedup, density, dtype):
mlx_gbps = []
torch_gbps = []
mlx_times = []
torch_times = []
for length in VECTOR_LENGTHS:
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
mlx_gbps.append(gbps_mlx)
torch_gbps.append(gbps_torch)
mlx_times.append(t_mlx)
torch_times.append(t_torch)
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
ax_perf.set_xscale("log", base=2)
ax_perf.set_xticks(VECTOR_LENGTHS)
formatter = FuncFormatter(_power_of_two_formatter)
ax_perf.xaxis.set_major_formatter(formatter)
ax_perf.set_title(f"density={density:.2f}")
ax_perf.set_ylabel("GB/s")
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
ax_perf.legend()
speedup = np.array(torch_times) / np.array(mlx_times)
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
ax_speedup.set_xscale("log", base=2)
ax_speedup.set_xticks(VECTOR_LENGTHS)
ax_speedup.xaxis.set_major_formatter(formatter)
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
def main():
for dtype in D_TYPES:
fig, axs = plt.subplots(
len(MASK_DENSITIES),
2,
figsize=(10, 12),
layout="constrained",
sharex=True,
)
for i, density in enumerate(MASK_DENSITIES):
plot_density(axs[i][0], axs[i][1], density, dtype)
axs[i][0].set_xlabel("vector length")
axs[i][1].set_xlabel("vector length")
fig.suptitle(
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
)
output_path = os.path.join(
RESULTS_DIR,
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
)
fig.savefig(output_path)
plt.close(fig)
if __name__ == "__main__":
main()

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -1,3 +0,0 @@
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.

View File

@@ -1,5 +1,4 @@
sphinx sphinx
breathe breathe
sphinx-book-theme sphinx-book-theme
sphinx-copybutton
mlx mlx

View File

@@ -18,7 +18,6 @@ release = version
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
extensions = [ extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc", "sphinx.ext.autodoc",
"sphinx.ext.autosummary", "sphinx.ext.autosummary",
"sphinx.ext.intersphinx", "sphinx.ext.intersphinx",

View File

@@ -127,8 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
name="myexp_strided", name="myexp_strided",
input_names=["inp"], input_names=["inp"],
output_names=["out"], output_names=["out"],
source=source, source=source
ensure_row_contiguous=False,
) )
def exp_elementwise(a: mx.array): def exp_elementwise(a: mx.array):
@@ -139,6 +138,7 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes=[a.shape], output_shapes=[a.shape],
output_dtypes=[a.dtype], output_dtypes=[a.dtype],
ensure_row_contiguous=False,
) )
return outputs[0] return outputs[0]
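Putting the two pieces of this hunk together, a minimal end-to-end sketch of the call pattern. The kernel body and names are illustrative (not taken from the diff), and the sketch follows the right-hand column, which supplies ``ensure_row_contiguous`` at call time rather than in the constructor:

import mlx.core as mx

# Flat indexing, so this body is only correct for row-contiguous inputs;
# with ensure_row_contiguous=False the copy is skipped and layout handling
# becomes the kernel's responsibility.
source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = metal::exp(inp[elem]);
"""

kernel = mx.fast.metal_kernel(
    name="myexp_strided",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

def exp_elementwise(a: mx.array):
    outputs = kernel(
        inputs=[a],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        ensure_row_contiguous=False,  # placement as in the right-hand column
    )
    return outputs[0]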

View File

@@ -70,7 +70,6 @@ are the CPU and GPU.
python/fft python/fft
python/linalg python/linalg
python/metal python/metal
python/cuda
python/memory_management python/memory_management
python/nn python/nn
python/optimizers python/optimizers

View File

@@ -16,11 +16,12 @@ silicon computer is
To install from PyPI your system must meet the following requirements: To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon) - Using an M series chip (Apple silicon)
- Using a native Python >= 3.10 - Using a native Python >= 3.9
- macOS >= 14.0 - macOS >= 13.5
.. note:: .. note::
MLX is only available on devices running macOS >= 14.0 and higher. MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
CUDA CUDA
^^^^ ^^^^
@@ -38,7 +39,7 @@ requirements:
- Nvidia driver >= 550.54.14 - Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0 - CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.10 - Python >= 3.9
CPU-only (Linux) CPU-only (Linux)
@@ -54,7 +55,7 @@ To install the CPU-only package from PyPi your system must meet the following
requirements: requirements:
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.10 - Python >= 3.9
Troubleshooting Troubleshooting
@@ -270,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y apt-get update -y
apt-get -y install cuda-toolkit-12-9 apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag When building either the Python or C++ APIs make sure to pass the cmake flag

View File

@@ -1,9 +0,0 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,4 +13,3 @@ Fast
rope rope
scaled_dot_product_attention scaled_dot_product_attention
metal_kernel metal_kernel
cuda_kernel

View File

@@ -27,7 +27,6 @@ simple functions.
mish mish
prelu prelu
relu relu
relu2
relu6 relu6
selu selu
sigmoid sigmoid

View File

@@ -50,7 +50,6 @@ Layers
QuantizedLinear QuantizedLinear
RMSNorm RMSNorm
ReLU ReLU
ReLU2
ReLU6 ReLU6
RNN RNN
RoPE RoPE

View File

@@ -112,7 +112,6 @@ Operations
max max
maximum maximum
mean mean
median
meshgrid meshgrid
min min
minimum minimum

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads) optimizer.update(model, grads)
# Save the state # Save the state
state = tree_flatten(optimizer.state, destination={}) state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", state) mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint, # Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state # recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2) optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors")) state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state optimizer.state = state
Note, not every optimizer configuration parameter is saved in the state. For Note, not every optimizer configuration parameter is saved in the state. For
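A complete round trip using the list-based ``tree_flatten`` convention from the right-hand column (a minimal sketch; the Adam/Linear setup is only for illustration):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten

model = nn.Linear(4, 4)
optimizer = optim.Adam(learning_rate=1e-2)
optimizer.init(model.trainable_parameters())

# Save: flatten the nested state into string-keyed arrays.
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))

# Load: rebuild the tree and hand it back to a freshly created optimizer.
optimizer = optim.Adam(learning_rate=1e-2)
optimizer.state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))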

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python .. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096)) x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x) timeit(nn.gelu, x)
timeit(mx.compile(gelu), x) timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
def fun(x, y): def fun(x, y):
z = x + y z = x + y
state.append(z) state.append(z)
return mx.exp(z) return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0)) fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)] # Prints [array(3, dtype=float32)]
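The timing comparison above relies on a ``timeit`` helper defined earlier in that guide; a self-contained sketch of the same measurement (the helper below is hypothetical, not part of MLX):

import time
import mlx.core as mx
import mlx.nn as nn

def timeit(fn, x, iters=100):
    mx.eval(fn(x))  # warm up; also triggers compilation for compiled functions
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(x))  # force each iteration to actually compute
    return 1e3 * (time.perf_counter() - tic) / iters  # ms per call

x = mx.random.uniform(shape=(32, 1000, 4096))
print(timeit(nn.gelu, x))              # eager
print(timeit(mx.compile(nn.gelu), x))  # compiled; the guide above reports a
                                       # roughly five-fold speedup on an M1 Max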

View File

@@ -7,13 +7,12 @@ Distributed Communication
MLX supports distributed communication operations that allow the computational cost MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the of training or inference to be shared across many physical machines. At the
moment we support three different communication backends: moment we support two different communication backends:
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a * `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets. It should be * A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections, but it also works over Ethernet. faster for thunderbolt connections.
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
The list of all currently supported operations and their documentation can be The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`. seen in the :ref:`API docs<distributed>`.
@@ -85,8 +84,9 @@ Selecting Backend
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
available backends. If they all fail then a singleton group is created. initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
.. note:: .. note::
After a distributed backend is successfully initialized :func:`init` will After a distributed backend is successfully initialized :func:`init` will
@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y): def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y) loss, grads = loss_grad_fn(model, x, y)
grads = mx.nn.average_gradients(grads) # <---- This line was added grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads) optimizer.update(model, grads)
return loss return loss
@@ -220,7 +220,7 @@ print 4 etc.
Installing MPI Installing MPI
^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows: with the Anaconda package manager as follows:
@@ -228,16 +228,14 @@ with the Anaconda package manager as follows:
$ conda install conda-forge::openmpi $ conda install conda-forge::openmpi
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld`` Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``. Some environments use a non-standard done automatically by ``mlx.launch``.
library filename that can be specified using the ``MPI_LIBNAME`` environment
variable. This is automatically taken care of by ``mlx.launch`` as well.
.. code:: shell .. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply $ # or simply
$ mlx.launch -n 2 test.py $ mlx.launch -n 2 test.py
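A minimal sanity check for whichever backend ``init`` ends up selecting, launched the same way as ``test.py`` above (the shapes are arbitrary):

import mlx.core as mx
import mlx.nn as nn

world = mx.distributed.init()  # picks a backend as described above
x = mx.ones(4)
# Every process contributes its x; all of them receive the sum.
y = mx.distributed.all_sum(x)
print(world.rank(), world.size(), y)

# In a training step, gradients are averaged across processes before the
# optimizer update, e.g. grads = nn.average_gradients(grads).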

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++). front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples. This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation To see the full list of functions check-out the :ref:`API documentation
<export>`. <export>`.
Basics of Exporting Basics of Exporting
------------------- -------------------
Let's start with a simple example: Let's start with a simple example:
.. code-block:: python .. code-block:: python
def fun(x, y): def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0) x = mx.array(1.0)
y = mx.array(1.0) y = mx.array(1.0)
# Both arguments to fun are positional # Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y) mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs. the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters())``, the If the above example was missing ``mx.eval(model.parameters())``, the
exported function would include the random initialization of the exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters. :obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters # Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items()))) model.update(tree_unflatten(list(params.items())))
return model(x) return model(x)
params = tree_flatten(model.parameters(), destination={}) params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -164,13 +164,13 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python .. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True) mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn") imported_abs = mx.import_function("fun.mlxfn")
# Ok # Ok
out, = imported_abs(mx.array([-1.0])) out, = imported_abs(mx.array(-1.0))
# Also ok # Also ok
out, = imported_abs(mx.array([-1.0, -2.0])) out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None): def fun(x, y=None):
constant = mx.array(3.0) constant = mx.array(3.0)
if y is not None: if y is not None:
x += y x += y
return x + constant return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter: with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out) print(out)
In the above example the function constant data (i.e. ``constant``) is only In the above example the function constant data (i.e. ``constant``) is only
saved once. saved once.
Transformations with Imported Functions Transformations with Imported Functions
--------------------------------------- ---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32) # Prints: array(1, dtype=float32)
print(dfdx(x)) print(dfdx(x))
# Compile the imported function # Compile the imported function
mx.compile(imported_fun) mx.compile(imported_fun)
# Prints: array(0, dtype=float32) # Prints: array(0, dtype=float32)
print(compiled_fun(x)[0]) print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32) // Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl; std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string, ``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++. mx::array>`` for keyword arguments when calling imported functions in C++.
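Pulling the pieces above together, a minimal export/import round trip on the Python side (the file name and function are illustrative):

import mlx.core as mx

def fun(x, y):
    return x + y

x = mx.array(1.0)
y = mx.array(2.0)

# Export traces `fun` with the example inputs and writes the graph to disk.
mx.export_function("add.mlxfn", fun, x, y)

# Imported functions always return a tuple of outputs.
imported_fun = mx.import_function("add.mlxfn")
(out,) = imported_fun(x, y)
print(out)  # array(3, dtype=float32)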

View File

@@ -70,8 +70,7 @@ Differences from NumPy
* Indexing does not perform bounds checking. Indexing out of bounds is * Indexing does not perform bounds checking. Indexing out of bounds is
undefined behavior. undefined behavior.
* Boolean mask based indexing is supported for assignment only (see * Boolean mask based indexing is not yet supported.
:ref:`boolean-mask-assignment`).
The reason for the lack of bounds checking is that exceptions cannot propagate The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the from the GPU. Performing bounds checking for array indices before launching the
@@ -108,20 +107,8 @@ same array:
>>> a >>> a
array([1, 2, 0], dtype=int32) array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
.. code-block:: shell Note, unlike NumPy, updates to the same location are nondeterministic:
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell .. code-block:: shell
@@ -144,51 +131,3 @@ expected. For example:
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx`` In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere. and ones elsewhere.
.. _boolean-mask-assignment:
Boolean Mask Assignment
-----------------------
MLX supports boolean indices using NumPy syntax. A mask must already be
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
Other index types are routed through the standard scatter code.
.. code-block:: shell
>>> a = mx.array([1.0, 2.0, 3.0])
>>> mask = mx.array([True, False, True])
>>> updates = mx.array([5.0, 6.0])
>>> a[mask] = updates
>>> a
array([5.0, 2.0, 6.0], dtype=float32)
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
assignments, ``updates`` must provide at least as many elements as there are
``True`` entries in ``mask``.
.. code-block:: shell
>>> a = mx.zeros((2, 3))
>>> mask = mx.array([[True, False, True],
[False, False, True]])
>>> a[mask] = 1.0
>>> a
array([[1.0, 0.0, 1.0],
[0.0, 0.0, 1.0]], dtype=float32)
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. The only
exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full.
.. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
axes and therefore raise errors.
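When a build does not support boolean-mask assignment (the right-hand column above), the same full-array update can be expressed functionally with ``mx.where``, which works on either side of this change (a minimal sketch, covering only the scalar-broadcast case):

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
mask = mx.array([True, False, True])

# Functional equivalent of `a[mask] = 5.0` for a full-shape mask.
# Scattering a separate vector of updates into the True positions still
# needs the mask-assignment path or an explicit index computation.
a = mx.where(mask, 5.0, a)
print(a)  # array([5, 2, 5], dtype=float32)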

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_; void* ptr_;
public: public:
explicit Buffer(void* ptr) : ptr_(ptr) {}; Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer // Get the raw data pointer from the buffer
void* raw_ptr(); void* raw_ptr();

View File

@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(), other.strides(),
other.flags(), other.flags(),
[](auto) {}); [](auto) {});
cpy.array_desc_->offset = other.array_desc_->offset; cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy; return cpy;
} }
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) { void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d); array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->offset = 0; array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size(); array_desc_->data_size = size();
array_desc_->flags.contiguous = true; array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true; array_desc_->flags.row_contiguous = true;
@@ -156,7 +156,7 @@ void array::set_data(
Flags flags, Flags flags,
Deleter d) { Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d); array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->offset = 0; array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size; array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides); array_desc_->strides = std::move(strides);
array_desc_->flags = flags; array_desc_->flags = flags;
@@ -167,13 +167,14 @@ void array::copy_shared_buffer(
const Strides& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
int64_t offset /* = 0 */) { size_t offset /* = 0 */) {
array_desc_->data = other.array_desc_->data; array_desc_->data = other.array_desc_->data;
array_desc_->strides = strides; array_desc_->strides = strides;
array_desc_->flags = flags; array_desc_->flags = flags;
array_desc_->data_size = data_size; array_desc_->data_size = data_size;
array_desc_->offset = auto char_offset = sizeof(char) * itemsize() * offset;
sizeof(char) * itemsize() * offset + other.array_desc_->offset; array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
} }
void array::copy_shared_buffer(const array& other) { void array::copy_shared_buffer(const array& other) {
@@ -240,8 +241,8 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs) std::vector<array> inputs)
: shape(std::move(shape)), : shape(std::move(shape)),
dtype(dtype), dtype(dtype),
primitive(std::move(primitive)),
status(Status::unscheduled), status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) { inputs(std::move(inputs)) {
init(); init();
} }

View File

@@ -294,11 +294,6 @@ class array {
return array_desc_->siblings; return array_desc_->siblings;
} }
/** The array's position in the sibling list. */
int sibling_position() const {
return array_desc_->position;
}
void set_siblings(std::vector<array> siblings, uint16_t position) { void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings); array_desc_->siblings = std::move(siblings);
array_desc_->position = position; array_desc_->position = position;
@@ -354,23 +349,15 @@ class array {
return array_desc_->data; return array_desc_->data;
} }
// Return a raw pointer to the arrays data. This function may do a copy if // Return a raw pointer to the arrays data
// the underlying buffer is not accessible on the CPU. When accessing the
// data for GPU kernels, be sure to use the correct method / function for the
// given backend to access the GPU pointer.
template <typename T> template <typename T>
T* data() { T* data() {
return reinterpret_cast<T*>( return static_cast<T*>(array_desc_->data_ptr);
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
} }
template <typename T> template <typename T>
const T* data() const { const T* data() const {
return const_cast<array&>(*this).data<T>(); return static_cast<T*>(array_desc_->data_ptr);
}
int64_t offset() const {
return array_desc_->offset;
} }
enum Status { enum Status {
@@ -439,7 +426,7 @@ class array {
const Strides& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
int64_t offset = 0); size_t offset = 0);
void copy_shared_buffer(const array& other); void copy_shared_buffer(const array& other);
@@ -474,8 +461,8 @@ class array {
// can share the underlying data buffer. // can share the underlying data buffer.
std::shared_ptr<Data> data; std::shared_ptr<Data> data;
// Offset from beginning of data pointer // Properly offset data pointer
int64_t offset{0}; void* data_ptr{nullptr};
// The size in elements of the data buffer the array accesses // The size in elements of the data buffer the array accesses
size_t data_size; size_t data_size;

View File

@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a, const array& a,
const array& b, const array& b,
array& out, array& out,
BinaryOpType bopt, BinaryOpType bopt) {
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
bool b_donatable = is_donatable(b, out); bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out); bool a_donatable = is_donatable(a, out);
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags()); out.set_data(
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
if (b_donatable) { if (b_donatable) {
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data( out.set_data(
mallocfn(b.data_size() * out.itemsize()), allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} else { } else {
out.set_data( out.set_data(
mallocfn(a.data_size() * out.itemsize()), allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data( out.set_data(
mallocfn(a.data_size() * out.itemsize()), allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) { b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data(mallocfn(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
break; break;
} }

View File

@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) { void broadcast(const array& in, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(allocator::malloc(0)); out.set_data(nullptr);
return; return;
} }
Strides strides(out.ndim(), 0); Strides strides(out.ndim(), 0);

View File

@@ -114,9 +114,7 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant, const std::function<bool(size_t)>& is_constant,
bool contiguous, bool contiguous) {
const std::function<allocator::Buffer(size_t)>&
mallocfn /* = allocator::malloc */) {
if (contiguous) { if (contiguous) {
int o = 0; int o = 0;
Strides strides; Strides strides;
@@ -142,7 +140,7 @@ void compiled_allocate_outputs(
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data( outputs[o].set_data(
mallocfn(data_size * outputs[o].itemsize()), allocator::malloc(data_size * outputs[o].itemsize()),
data_size, data_size,
strides, strides,
flags); flags);
@@ -165,7 +163,7 @@ void compiled_allocate_outputs(
} }
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data(mallocfn(outputs[o].nbytes())); outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
} }
} }
} }

View File

@@ -58,9 +58,7 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant, const std::function<bool(size_t)>& is_constant,
bool contiguous, bool contiguous);
const std::function<allocator::Buffer(size_t)>& mallocfn =
allocator::malloc);
// Collapse contiguous dims ignoring scalars and constants. // Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims( std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(

View File

@@ -22,11 +22,7 @@ enum class CopyType {
GeneralGeneral GeneralGeneral
}; };
inline bool set_copy_output_data( inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types // If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output. // have the same size, then the input buffer can hold the output.
@@ -35,14 +31,14 @@ inline bool set_copy_output_data(
return true; return true;
} else { } else {
out.set_data( out.set_data(
mallocfn(in.data_size() * out.itemsize()), allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(), in.data_size(),
in.strides(), in.strides(),
in.flags()); in.flags());
return false; return false;
} }
} else { } else {
out.set_data(mallocfn(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
return false; return false;
} }
} }

View File

@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a, const array& a,
const array& b) { const array& b) {
if (a.ndim() == 2) { if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}}; return {{1}, {0}, {0}};
} }
Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
inline std::tuple<Shape, Strides, Strides, Strides> inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) { collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) { if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}}; return {{1}, {0}, {0}, {0}};
} }
Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Shape A_bshape{a.shape().begin(), a.shape().end() - 2};

View File

@@ -14,13 +14,17 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i]; data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i]; inp_strides[i] = in.strides()[i] * strides[i];
} }
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides); return std::make_tuple(data_offset, inp_strides);
} }
void shared_buffer_slice( void shared_buffer_slice(
const array& in, const array& in,
const Strides& out_strides, const Strides& out_strides,
int64_t data_offset, size_t data_offset,
size_t data_size, size_t data_size,
array& out) { array& out) {
// Compute row/col contiguity // Compute row/col contiguity
@@ -41,30 +45,23 @@ void slice(
const Shape& start_indices, const Shape& start_indices,
const Shape& strides) { const Shape& strides) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(allocator::malloc(0)); out.set_data(nullptr);
return; return;
} }
// Calculate out strides, initial offset // Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides); auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
// Get the location of the end based on the inp strides and out.shape() for (int i = 0; i < start_indices.size(); ++i) {
int64_t low_idx = 0; if (in.shape()[i] > 1) {
int64_t high_idx = 0; auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
for (int i = 0; i < inp_strides.size(); ++i) { data_end += end_idx * in.strides()[i];
auto delta = inp_strides[i] * (out.shape()[i] - 1);
if (inp_strides[i] > 0) {
high_idx += delta;
} else {
low_idx += delta;
} }
} }
int64_t data_size = (high_idx - low_idx) + 1; if (data_end < 0) {
if (data_size < 0) { data_end += in.data_size();
std::ostringstream msg;
msg << "[slice] Computed invalid data size: " << data_size << ".";
throw std::runtime_error(msg.str());
} }
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out); shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
} }
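A plain-Python rendering (hypothetical helper, not MLX API) of the data-size arithmetic in the left-hand column above: positive stride deltas accumulate into a high index and negative ones into a low index, so slices with negative strides still get a valid extent.

def slice_data_size(out_shape, inp_strides):
    low_idx = 0
    high_idx = 0
    for dim, stride in zip(out_shape, inp_strides):
        delta = stride * (dim - 1)
        if stride > 0:
            high_idx += delta
        else:
            low_idx += delta
    return (high_idx - low_idx) + 1

# A reversed, strided slice of 3 elements with stride -2 spans 5 elements.
print(slice_data_size((3,), (-2,)))  # 5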

View File

@@ -11,8 +11,6 @@ namespace mlx::core {
enum class TernaryOpType { enum class TernaryOpType {
ScalarScalarScalar, ScalarScalarScalar,
VectorVectorVector, VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General, General,
}; };
@@ -27,14 +25,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous && (a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) { c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector; topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else { } else {
topt = TernaryOpType::General; topt = TernaryOpType::General;
} }
@@ -46,8 +36,7 @@ inline void set_ternary_op_output_data(
const array& b, const array& b,
const array& c, const array& c,
array& out, array& out,
TernaryOpType topt, TernaryOpType topt) {
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
auto maybe_donate = [&out](const array& x) { auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) { if (is_donatable(x, out)) {
out.copy_shared_buffer(x); out.copy_shared_buffer(x);
@@ -58,25 +47,24 @@ inline void set_ternary_op_output_data(
switch (topt) { switch (topt) {
case TernaryOpType::ScalarScalarScalar: case TernaryOpType::ScalarScalarScalar:
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags()); out.set_data(
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break; break;
case TernaryOpType::VectorVectorVector: case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data( out.set_data(
mallocfn(out.itemsize() * b.data_size()), allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
} }
break; break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General: case TernaryOpType::General:
// Try to donate an input which is row_contiguous // Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) || if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) || (b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) { (c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(mallocfn(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
break; break;
} }

View File

@@ -7,22 +7,19 @@
namespace mlx::core { namespace mlx::core {
inline void set_unary_output_data( inline void set_unary_output_data(const array& in, array& out) {
const array& in,
array& out,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (in.flags().contiguous) { if (in.flags().contiguous) {
if (is_donatable(in, out)) { if (is_donatable(in, out)) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
out.set_data( out.set_data(
mallocfn(in.data_size() * out.itemsize()), allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(), in.data_size(),
in.strides(), in.strides(),
in.flags()); in.flags());
} }
} else { } else {
out.set_data(mallocfn(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
} }

View File

@@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz)); std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
} }
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core } // namespace mlx::core
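A plain-Python analogy (using NumPy, not an MLX API) for what ``swapaxes_in_eval`` does above: swap two entries of the shape and strides and reuse the same underlying buffer, so no data is copied. MLX strides are in elements while NumPy's are in bytes, but the swap is the same idea.

import numpy as np

def swapaxes_view(x: np.ndarray, axis1: int, axis2: int) -> np.ndarray:
    shape = list(x.shape)
    strides = list(x.strides)
    shape[axis1], shape[axis2] = shape[axis2], shape[axis1]
    strides[axis1], strides[axis2] = strides[axis2], strides[axis1]
    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)

a = np.arange(6).reshape(2, 3)
b = swapaxes_view(a, 0, 1)
assert np.shares_memory(a, b) and (b == a.T).all()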

View File

@@ -196,6 +196,9 @@ void shared_buffer_reshape(
const Strides& out_strides, const Strides& out_strides,
array& out); array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T> template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) { inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index)); vec.erase(std::next(vec.begin(), index));

View File

@@ -14,11 +14,233 @@
namespace mlx::core { namespace mlx::core {
namespace {
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace
void Add::eval_cpu(const std::vector<array>& inputs, array& out) { void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Add(), stream()); binary(a, b, out, detail::Add(), stream());
} }
void DivMod::eval_cpu( void DivMod::eval_cpu(
@@ -102,14 +324,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Divide(), stream()); binary(a, b, out, detail::Divide(), stream());
} }
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) { void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Remainder(), stream()); binary(a, b, out, detail::Remainder(), stream());
} }
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) { void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -150,90 +372,89 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
}); });
} else { } else {
comparison_op_cpu(a, b, out, detail::Equal(), stream()); comparison_op(a, b, out, detail::Equal(), stream());
} }
} }
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) { void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream()); comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
} }
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op_cpu( comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
} }
void Less::eval_cpu(const std::vector<array>& inputs, array& out) { void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream()); comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
} }
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream()); comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
} }
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) { void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream()); binary_float(a, b, out, detail::LogAddExp(), stream());
} }
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0]; auto& in1 = inputs[0];
auto& in2 = inputs[1]; auto& in2 = inputs[1];
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream()); binary(in1, in2, out, detail::LogicalAnd(), stream());
} }
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0]; auto& in1 = inputs[0];
auto& in2 = inputs[1]; auto& in2 = inputs[1];
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream()); binary(in1, in2, out, detail::LogicalOr(), stream());
} }
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) { void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Maximum(), stream()); binary(a, b, out, detail::Maximum(), stream());
} }
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) { void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Minimum(), stream()); binary(a, b, out, detail::Minimum(), stream());
} }
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) { void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Multiply(), stream()); binary(a, b, out, detail::Multiply(), stream());
} }
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream()); comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
} }
void Power::eval_cpu(const std::vector<array>& inputs, array& out) { void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Power(), stream()); binary(a, b, out, detail::Power(), stream());
} }
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) { void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_op_cpu(a, b, out, detail::Subtract(), stream()); binary(a, b, out, detail::Subtract(), stream());
} }
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) { void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -242,19 +463,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
switch (op_) { switch (op_) {
case BitwiseBinary::And: case BitwiseBinary::And:
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream()); binary_int(a, b, out, detail::BitwiseAnd(), stream());
break; break;
case BitwiseBinary::Or: case BitwiseBinary::Or:
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream()); binary_int(a, b, out, detail::BitwiseOr(), stream());
break; break;
case BitwiseBinary::Xor: case BitwiseBinary::Xor:
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream()); binary_int(a, b, out, detail::BitwiseXor(), stream());
break; break;
case BitwiseBinary::LeftShift: case BitwiseBinary::LeftShift:
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream()); binary_int(a, b, out, detail::LeftShift(), stream());
break; break;
case BitwiseBinary::RightShift: case BitwiseBinary::RightShift:
binary_int_op_cpu(a, b, out, detail::RightShift(), stream()); binary_int(a, b, out, detail::RightShift(), stream());
break; break;
} }
} }
@@ -263,7 +484,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
const auto& a = inputs[0]; const auto& a = inputs[0];
const auto& b = inputs[1]; const auto& b = inputs[1];
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream()); binary_float(a, b, out, detail::ArcTan2(), stream());
} }
} // namespace mlx::core } // namespace mlx::core
@@ -7,7 +7,6 @@
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core { namespace mlx::core {
@@ -291,227 +290,4 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt); binary_op<T, T, Op>(a, b, out, bopt);
} }
template <typename Op>
void binary_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace mlx::core } // namespace mlx::core
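Aside, a minimal sketch of the kind of contiguity classification that get_binary_op_type/bopt in the wrappers above encode: scalar and contiguous operands can take fast paths, everything else falls back to a general strided loop. The enumerator names here are illustrative assumptions, not taken from the diff.

// Hypothetical classification helper; names are illustrative only.
enum class BinaryOpType { ScalarScalar, ScalarVector, VectorScalar, VectorVector, General };

BinaryOpType classify(bool a_scalar, bool a_contig, bool b_scalar, bool b_contig) {
  if (a_scalar && b_scalar) return BinaryOpType::ScalarScalar;
  if (a_scalar && b_contig) return BinaryOpType::ScalarVector;
  if (b_scalar && a_contig) return BinaryOpType::VectorScalar;
  if (a_contig && b_contig) return BinaryOpType::VectorVector;
  return BinaryOpType::General;
}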
@@ -15,7 +15,6 @@
#include "mlx/backend/cpu/jit_compiler.h" #include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/graph_utils.h" #include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core { namespace mlx::core {
@@ -95,11 +94,7 @@ void* compile(
kernel_file_name = kernel_name; kernel_file_name = kernel_name;
} }
auto output_dir = auto output_dir = std::filesystem::temp_directory_path();
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
std::string shared_lib_name = "lib" + kernel_file_name + ".so"; std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string(); auto shared_lib_path = (output_dir / shared_lib_name).string();
@@ -162,12 +157,10 @@ inline void build_kernel(
#endif #endif
// Start the kernel // Start the kernel
os << "void " << kernel_name os << "void " << kernel_name << "(void** args) {" << std::endl;
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
// Add the input arguments // Add the input arguments
int cnt = 0; int cnt = 0;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list // Skip constants from the input list
if (is_constant(i)) { if (is_constant(i)) {
@@ -182,8 +175,8 @@ inline void build_kernel(
<< "];" << std::endl; << "];" << std::endl;
// Scalars and contiguous need no strides // Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) { if (!is_scalar(x) && !contiguous) {
os << " const int64_t* " << xname << "_strides = strides[" os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< strides_index++ << "];" << std::endl; << "];" << std::endl;
} }
} }
@@ -193,8 +186,10 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl; << "*)args[" << cnt++ << "];" << std::endl;
} }
// Add output size // Add output strides and shape to extract the indices.
if (contiguous) { if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl; os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
} }
@@ -293,8 +288,17 @@ void Compiled::eval_cpu(
auto [contiguous, shape, strides] = auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Force allocating shape/strides on heap so we can take their data() first
// and then std::move them.
// TODO: Refactor code to avoid heap allocation.
shape.grow();
for (auto& s : strides) {
s.grow();
}
// Collect function input arguments. // Collect function input arguments.
std::vector<void*> args; std::vector<void*> args;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) { if (is_constant_(i)) {
continue; continue;
@@ -302,6 +306,9 @@ void Compiled::eval_cpu(
const auto& x = inputs[i]; const auto& x = inputs[i];
encoder.set_input_array(x); encoder.set_input_array(x);
args.push_back((void*)x.data<void>()); args.push_back((void*)x.data<void>());
if (!contiguous && !is_scalar(x)) {
args.push_back(strides[strides_index++].data());
}
} }
// Get the kernel name from the lib // Get the kernel name from the lib
@@ -336,20 +343,16 @@ void Compiled::eval_cpu(
args.push_back(x.data<void>()); args.push_back(x.data<void>());
encoder.set_output_array(x); encoder.set_output_array(x);
} }
if (contiguous) { if (!contiguous) {
args.push_back((void*)shape.data());
} else {
args.push_back((void*)outputs[0].data_size()); args.push_back((void*)outputs[0].data_size());
} }
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr); auto fun = (void (*)(void**))fn_ptr;
encoder.dispatch([fun, encoder.dispatch([fun,
args = std::move(args), args = std::move(args),
strides = std::move(strides), strides = std::move(strides),
shape = std::move(shape)]() mutable { shape = std::move(shape)]() mutable { fun(args.data()); });
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
} }
} // namespace mlx::core } // namespace mlx::core
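Aside: one side of the hunk above emits kernels that take the shape and per-array strides as separate arguments, while the other packs everything into args. A minimal sketch, with a hypothetical wrapper name, of casting a JIT-compiled symbol to the three-argument form and calling it:

#include <cstdint>
#include <vector>

// Signature matching the "(int* shape, int64_t** strides, void** args)" form.
using KernelFn = void (*)(int*, int64_t**, void**);

void call_jit_kernel(void* fn_ptr,
                     std::vector<int>& shape,
                     std::vector<std::vector<int64_t>>& strides,
                     std::vector<void*>& args) {
  // Collect raw stride pointers so the kernel sees one int64_t* per input.
  std::vector<int64_t*> stride_ptrs;
  for (auto& s : strides) {
    stride_ptrs.push_back(s.data());
  }
  auto fn = reinterpret_cast<KernelFn>(fn_ptr);
  fn(shape.data(), stride_ptrs.data(), args.data());
}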
@@ -996,6 +996,131 @@ void explicit_gemm_conv_1D_cpu(
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
} }
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
encoder.set_input_array(in_strided);
encoder.set_input_array(gemm_wt);
encoder.set_output_array(gemm_out);
encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
gemm_wt_ptr = gemm_wt.data<float>(),
gemm_out_ptr = gemm_out.data<float>(),
strided_reshape = std::move(strided_reshape),
O]() {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided_ptr,
strided_reshape[1], // lda
gemm_wt_ptr,
strided_reshape[1], // ldb
0.0f, // beta
gemm_out_ptr,
O // ldc
);
});
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
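Aside, a back-of-the-envelope sketch (illustrative numbers, unit dilation assumed) of the GEMM shapes the explicit 2D convolution above builds: the strided view is reshaped to an (N*oH*oW) x (wH*wW*C) matrix and multiplied against the O x (wH*wW*C) weights.

#include <cstdio>

int main() {
  int N = 1, iH = 8, iW = 8, C = 4, O = 16, wH = 3, wW = 3;
  int pad = 1, stride = 1;
  int oH = (iH + 2 * pad - wH) / stride + 1; // 8
  int oW = (iW + 2 * pad - wW) / stride + 1; // 8
  // GEMM operands: A is (N*oH*oW) x (wH*wW*C), B^T is O x (wH*wW*C).
  std::printf("M=%d K=%d N=%d\n", N * oH * oW, wH * wW * C, O);
  return 0;
}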
void explicit_gemm_conv_ND_cpu( void explicit_gemm_conv_ND_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
@@ -95,9 +95,4 @@ void Recv::eval_cpu(
distributed::detail::recv(group(), outputs[0], src_, stream()); distributed::detail::recv(group(), outputs[0], src_, stream());
} }
void ReduceScatter::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed } // namespace mlx::core::distributed
@@ -12,167 +12,6 @@ namespace mlx::core {
namespace { namespace {
template <typename T>
complex64_t to_complex(T r, T i) {
return {static_cast<float>(r), static_cast<float>(i)};
}
template <typename T, class Enable = void>
struct EigWork {};
template <typename T>
struct EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using O = complex64_t;
char jobl;
char jobr;
int N;
int lwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
T work;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
if (compute_eigenvectors) {
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
}
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, O* values, O* vectors) {
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
T* vec_tmp = nullptr;
if (vectors) {
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
}
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
eig_tmp,
eig_tmp + N,
vectors ? vec_tmp : nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
work,
&lwork,
&info);
for (int i = 0; i < N; ++i) {
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
}
if (vectors) {
for (int i = 0; i < N; ++i) {
if (values[i].imag() != 0) {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] =
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
vectors[(i + 1) * N + j] =
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
}
}
}
}
}
};
template <>
struct EigWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
using O = T;
char jobl;
char jobr;
int N;
int lwork;
int lrwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
T work;
R rwork;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&rwork,
&info);
lwork = static_cast<int>(work.real());
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
}
void run(T* a, T* values, T* vectors) {
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
values,
vectors,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&info);
}
};
template <typename T> template <typename T>
void eig_impl( void eig_impl(
array& a, array& a,
@@ -180,39 +19,102 @@ void eig_impl(
array& values, array& values,
bool compute_eigenvectors, bool compute_eigenvectors,
Stream stream) { Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>(); auto a_ptr = a.data<T>();
auto val_ptr = values.data<complex64_t>(); auto eig_ptr = values.data<OT>();
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_output_array(values); encoder.set_output_array(values);
complex64_t* vec_ptr = nullptr; OT* vec_ptr = nullptr;
if (compute_eigenvectors) { if (compute_eigenvectors) {
encoder.set_output_array(vectors); encoder.set_output_array(vectors);
vec_ptr = vectors.data<complex64_t>(); vec_ptr = vectors.data<OT>();
} }
encoder.dispatch([a_ptr, encoder.dispatch([a_ptr,
val_ptr,
vec_ptr, vec_ptr,
eig_ptr,
compute_eigenvectors, compute_eigenvectors,
N = vectors.shape(-1), N = vectors.shape(-1),
size = vectors.size()]() mutable { size = vectors.size()]() mutable {
// Work query
char jobr = 'N'; char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N'; char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
EigWork<T> work(jobl, jobr, N, compute_eigenvectors); auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) { for (size_t i = 0; i < size / (N * N); ++i) {
work.run(a_ptr, val_ptr, vec_ptr); geev<T>(
a_ptr += N * N; &jobl,
val_ptr += N; &jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) { if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N; vec_ptr += N * N;
} }
if (work.info != 0) { a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg; std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code " msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< work.info; << info;
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
} }
@@ -264,17 +166,8 @@ void Eig::eval_cpu(
case float32: case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream()); eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break; break;
case float64:
eig_impl<double>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case complex64:
eig_impl<std::complex<float>>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default: default:
throw std::runtime_error( throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
} }
} }
@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
#include "mlx/array.h" #include "mlx/array.h"
@@ -48,15 +49,9 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1]; size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>(); BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations" #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
}
const BNNSLayerParametersBroadcastMatMul gemm_params{ const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha, /* float alpha = */ alpha,
/* float beta = */ beta, /* float beta = */ beta,
@@ -88,47 +88,4 @@ void matmul<double>(
} }
} }
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core } // namespace mlx::core
@@ -747,108 +747,4 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
}); });
} }
template <typename T>
void masked_scatter_impl(const array& mask, const array& src, array& out) {
ContiguousIterator mask_it(mask);
ContiguousIterator src_it(src);
ContiguousIterator out_it(out);
const bool* mask_ptr = mask.data<bool>();
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
const size_t batch_count = mask.shape(0);
const size_t mask_batch_size = mask.size() / batch_count;
const size_t src_batch_size = src.size() / batch_count;
for (uint b = 0; b < batch_count; ++b) {
size_t src_consumed = 0;
src_it.seek(b * src_batch_size);
for (size_t i = 0; i < mask_batch_size; ++i) {
if (mask_ptr[mask_it.loc]) {
if (src_consumed >= src_batch_size) {
throw std::runtime_error(
"[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
}
dst_ptr[out_it.loc] = src_ptr[src_it.loc];
src_it.step();
++src_consumed;
}
mask_it.step();
out_it.step();
}
}
}
void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& dst = inputs[0];
auto& mask = inputs[1];
auto& src = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(dst, out, ctype, stream());
if (mask.size() == 0) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(mask);
encoder.set_input_array(src);
encoder.set_output_array(out);
encoder.dispatch([mask = array::unsafe_weak_copy(mask),
src = array::unsafe_weak_copy(src),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
masked_scatter_impl<bool>(mask, src, out);
break;
case uint8:
masked_scatter_impl<uint8_t>(mask, src, out);
break;
case uint16:
masked_scatter_impl<uint16_t>(mask, src, out);
break;
case uint32:
masked_scatter_impl<uint32_t>(mask, src, out);
break;
case uint64:
masked_scatter_impl<uint64_t>(mask, src, out);
break;
case int8:
masked_scatter_impl<int8_t>(mask, src, out);
break;
case int16:
masked_scatter_impl<int16_t>(mask, src, out);
break;
case int32:
masked_scatter_impl<int32_t>(mask, src, out);
break;
case int64:
masked_scatter_impl<int64_t>(mask, src, out);
break;
case float16:
masked_scatter_impl<float16_t>(mask, src, out);
break;
case float32:
masked_scatter_impl<float>(mask, src, out);
break;
case float64:
masked_scatter_impl<double>(mask, src, out);
break;
case bfloat16:
masked_scatter_impl<bfloat16_t>(mask, src, out);
break;
case complex64:
masked_scatter_impl<complex64_t>(mask, src, out);
break;
}
});
}
} // namespace mlx::core } // namespace mlx::core
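Aside, an illustrative scalar sketch (not MLX code) of the masked-scatter semantics implemented above: positions where the mask is true receive consecutive source elements, and every other position keeps the destination value.

#include <cassert>
#include <vector>

int main() {
  std::vector<float> dst = {0, 0, 0, 0};
  std::vector<bool> mask = {true, false, true, false};
  std::vector<float> src = {7, 9};
  int next_src = 0;
  for (int i = 0; i < 4; ++i) {
    if (mask[i]) {
      dst[i] = src[next_src++]; // consume src in order at masked positions
    }
  }
  assert((dst == std::vector<float>{7, 0, 9, 0}));
  return 0;
}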
@@ -45,7 +45,9 @@
INSTANTIATE_LAPACK_REAL(geqrf) INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr) INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf) INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri) INSTANTIATE_LAPACK_REAL(trtri)
@@ -61,20 +63,3 @@ INSTANTIATE_LAPACK_REAL(trtri)
} }
INSTANTIATE_LAPACK_COMPLEX(heevd) INSTANTIATE_LAPACK_COMPLEX(heevd)
#define INSTANTIATE_LAPACK_ALL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, double>) { \
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_ALL(geev)
INSTANTIATE_LAPACK_ALL(gesdd)
@@ -215,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
const void* a_mask_ptr = nullptr; const void* a_mask_ptr;
const void* b_mask_ptr = nullptr; const void* b_mask_ptr;
const void* out_mask_ptr = nullptr; const void* out_mask_ptr;
Shape a_mask_shape; Shape a_mask_shape;
Shape b_mask_shape; Shape b_mask_shape;
Shape out_mask_shape; Shape out_mask_shape;
Strides a_mask_strides; Strides a_mask_strides;
Strides b_mask_strides; Strides b_mask_strides;
Strides out_mask_strides; Strides out_mask_strides;
bool a_mask_bool = false; bool a_mask_bool;
bool b_mask_bool = false; bool b_mask_bool;
bool out_mask_bool = false; bool out_mask_bool;
if (has_op_mask) { if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2]; auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1]; auto& b_mask = inputs[inputs.size() - 1];
@@ -423,6 +423,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& rhs_indices = inputs[3]; auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape()); auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape()); auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides()); auto batch_strides_A = get_batch_dims(a.strides());
@@ -2,8 +2,6 @@
#include <cstring> #include <cstring>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h" #include "mlx/backend/cpu/gemm.h"
@@ -93,6 +91,7 @@ void matmul_general(
auto [b_transposed, ldb, b] = check_transpose(b_pre); auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2); size_t M = a.shape(-2);
size_t N = b.shape(-1); size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) { if (M == 0 || N == 0) {
return; return;
} }
@@ -109,9 +108,6 @@ void matmul_general(
} else if (out.dtype() == float64) { } else if (out.dtype() == float64) {
matmul_dispatch<double>( matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream); a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else { } else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type."); throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
} }
@@ -132,34 +128,24 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
return; return;
} }
// Handle empty matrix case (K=0)
if (inputs[0].shape(-1) == 0) {
auto& c = inputs[2];
if (beta_ == 1.0f) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
} else {
array beta_scalar = array(beta_, c.dtype());
auto& encoder = cpu::get_command_encoder(stream());
binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream());
encoder.add_temporary(std::move(beta_scalar));
}
return;
}
// Fill output with C // Fill output with C
auto& c = inputs[2]; auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 CopyType ctype = c.data_size() == 1
? CopyType::Scalar ? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream()); copy_cpu(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
} }
@@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) { void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(allocator::malloc(0)); out.set_data(nullptr);
return; return;
} }
auto& in = inputs[0]; auto& in = inputs[0];
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out) { array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(allocator::malloc(0)); out.set_data(nullptr);
return; return;
} }
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) { void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(allocator::malloc(0)); out.set_data(nullptr);
return; return;
} }
@@ -1,11 +1,10 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h" #include <cassert>
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -14,35 +13,6 @@ namespace mlx::core {
namespace { namespace {
const static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
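Aside: the 8-bit MXFP4 scale decoded above is an exponent-only (E8M0) value, so shifting it into the bfloat16 exponent field yields 2^(s - 127); the s == 0 special case maps to the subnormal bit pattern 0x40, i.e. 2^-127. A minimal sketch of the same decoding in plain float arithmetic (illustrative, not the kernel code):

#include <cassert>
#include <cmath>
#include <cstdint>

float decode_e8m0(uint8_t s) {
  return std::ldexp(1.0f, static_cast<int>(s) - 127); // 2^(s - 127)
}

int main() {
  assert(decode_e8m0(127) == 1.0f);
  assert(decode_e8m0(129) == 4.0f);
  return 0;
}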
inline constexpr short get_pack_factor(int bits, int wsize = 8) { inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
} }
@@ -437,229 +407,6 @@ void _qmm_dispatch(
} }
} }
template <typename T>
void mxfp4_qmm(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
for (int ng = 0; ng < packs_in_group; ng++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
}
}
result += N;
}
}
template <typename T>
void mxfp4_qmm_t(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
T gsum = 0;
for (int kw = 0; kw < packs_in_group; kw++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
sum += scale * gsum;
}
*result = sum;
result++;
}
x += K;
}
}
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
if constexpr (S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & 0xf;
simd::Simd<float, S> w_out;
for (int i = 0; i < S; ++i) {
w_out[i] = MXFP4_LUT[wi[i]];
}
return w_out;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T>
void mxfp4_qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (simd::max_size<T> % 8 == 0) {
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
} else {
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
}
} else {
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
}
}
template <typename T>
void mxfp4_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
for (int i = 0; i < batch_size; i++) {
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
switch (x.dtype()) {
case bfloat16:
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
break;
case float16:
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
break;
case float32:
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T> template <typename T>
void _bs_qmm_dispatch_typed( void _bs_qmm_dispatch_typed(
array& out, array& out,
@@ -766,198 +513,115 @@ void _bs_qmm_dispatch(
} }
} }
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
switch (x.dtype()) {
case float32:
mxfp4_bs_qmm_dispatch_typed<float>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case float16:
mxfp4_bs_qmm_dispatch_typed<float16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case bfloat16:
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace } // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) { void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];
auto& w_pre = inputs[1]; auto& w_pre = inputs[1];
auto& scales_pre = inputs[2]; auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& encoder = cpu::get_command_encoder(stream()); std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) { auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return arr; return arr;
} else { } else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, arr_cpy, CopyType::General, s); copy_cpu(arr, temps.back(), CopyType::General, s);
encoder.add_temporary(arr_cpy); return temps.back();
return arr_cpy;
} }
}; };
auto x = ensure_row_contiguous(x_pre); auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre); auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre); auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(scales); encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out); encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) { encoder.dispatch([out = array::unsafe_weak_copy(out),
auto biases = ensure_row_contiguous(inputs[3]); x = array::unsafe_weak_copy(x),
encoder.set_input_array(biases); w = array::unsafe_weak_copy(w),
encoder.dispatch([out = array::unsafe_weak_copy(out), scales = array::unsafe_weak_copy(scales),
x = array::unsafe_weak_copy(x), biases = array::unsafe_weak_copy(biases),
w = array::unsafe_weak_copy(w), group_size_ = group_size_,
scales = array::unsafe_weak_copy(scales), bits_ = bits_,
biases = array::unsafe_weak_copy(biases), transpose_ = transpose_]() mutable {
group_size_ = group_size_, _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
bits_ = bits_, });
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
} }
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];
auto& w_pre = inputs[1]; auto& w_pre = inputs[1];
auto& scales_pre = inputs[2]; auto& scales_pre = inputs[2];
auto& lhs_indices = inputs[inputs.size() - 2]; auto& biases_pre = inputs[3];
auto& rhs_indices = inputs[inputs.size() - 1]; auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto& encoder = cpu::get_command_encoder(stream()); std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(), auto ensure_row_contiguous_last_dims = [s = stream(),
&encoder](const array& arr) { &temps](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2]; auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1]; auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) { if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr; return arr;
} else { } else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, arr_cpy, CopyType::General, s); copy_cpu(arr, temps.back(), CopyType::General, s);
encoder.add_temporary(arr_cpy); return temps.back();
return arr_cpy;
} }
}; };
auto x = ensure_row_contiguous_last_dims(x_pre); auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre); auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre); auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(scales); encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices); encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices); encoder.set_input_array(rhs_indices);
encoder.set_output_array(out); encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) { encoder.dispatch([out = array::unsafe_weak_copy(out),
auto biases = ensure_row_contiguous_last_dims(inputs[3]); x = array::unsafe_weak_copy(x),
encoder.set_input_array(biases); w = array::unsafe_weak_copy(w),
encoder.dispatch([out = array::unsafe_weak_copy(out), scales = array::unsafe_weak_copy(scales),
x = array::unsafe_weak_copy(x), biases = array::unsafe_weak_copy(biases),
w = array::unsafe_weak_copy(w), lhs_indices = array::unsafe_weak_copy(lhs_indices),
scales = array::unsafe_weak_copy(scales), rhs_indices = array::unsafe_weak_copy(rhs_indices),
biases = array::unsafe_weak_copy(biases), group_size_ = group_size_,
lhs_indices = array::unsafe_weak_copy(lhs_indices), bits_ = bits_,
rhs_indices = array::unsafe_weak_copy(rhs_indices), transpose_ = transpose_]() mutable {
group_size_ = group_size_, _bs_qmm_dispatch(
bits_ = bits_, out,
transpose_ = transpose_]() mutable { x,
_bs_qmm_dispatch( w,
out, scales,
x, biases,
w, lhs_indices,
scales, rhs_indices,
biases, group_size_,
lhs_indices, bits_,
rhs_indices, transpose_);
group_size_, });
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
} }
template <typename T, typename U> template <typename T, typename U>
@@ -1041,7 +705,7 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
} }
void fast::Quantize::eval_cpu( void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) { auto ensure_row_contiguous = [s = stream()](const array& arr) {
@@ -1100,47 +764,7 @@ void fast::Quantize::eval_cpu(
} }
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"[fast::Quantize::eval_cpu] Only supports floating point inputs"); "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
} }
}); });
} }

View File

@@ -491,27 +491,19 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: case bool_:
case uint8: case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8: case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break; break;
case int16: case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break; break;
case int32: case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break; break;
case int64: case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break; break;
case float16: case float16:
@@ -1,6 +1,5 @@
#pragma once #pragma once
#include <arm_neon.h>
#include <simd/math.h> #include <simd/math.h>
#include <simd/vector.h> #include <simd/vector.h>
@@ -10,7 +9,7 @@
#include "mlx/backend/cpu/simd/base_simd.h" #include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in simd/base_simd.h // There seems to be a bug in sims/base.h
// __XROS_2_0 is not defined, the expression evaluates // __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library // to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15 // higher than it should be even on macOS < 15
@@ -201,15 +200,6 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==) SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=) SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N> template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) { Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value); return asd::atan2(a.value, b.value);
@@ -217,20 +207,14 @@ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
template <typename T, int N> template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) { Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
auto out = Simd<T, N>(asd::max(a.value, b.value)); // TODO add isnan
if constexpr (!std::is_integral_v<T>) { return asd::max(a.value, b.value);
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
} }
template <typename T, int N> template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) { Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
auto out = Simd<T, N>(asd::min(a.value, b.value)); // TODO add isnan
if constexpr (!std::is_integral_v<T>) { return asd::min(a.value, b.value);
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
} }
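Aside, a scalar sketch (illustrative only) of the NaN-propagating behaviour the maximum/minimum changes above add for non-integral types: if either operand is NaN, that NaN is returned instead of the other value.

#include <cmath>

float nan_propagating_max(float a, float b) {
  float out = a > b ? a : b;     // plain max when neither operand is NaN
  if (std::isnan(a)) out = a;    // propagate NaN from either side
  if (std::isnan(b)) out = b;
  return out;
}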
template <typename T, int N> template <typename T, int N>
@@ -250,7 +234,6 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N> template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) { Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) { if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value)); return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) { } else if constexpr (sizeof(T1) == 2) {
@@ -268,13 +251,9 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value); return asd::pow(base.value, exp.value);
} else { } else {
Simd<T, N> res = 1; Simd<T, N> res = 1;
// Raising an integer to a negative power is undefined while (any(exp)) {
if (any(exp < 0)) { res = select(exp & 1, res * base, res);
return 0; base = select(exp, base * base, base);
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
exp = exp >> 1; exp = exp >> 1;
} }
return res; return res;
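Aside, a scalar sketch (illustrative, not the SIMD code above) of guarded exponentiation by squaring: negative exponents are rejected up front, and the loop terminates because the exponent is shifted toward zero.

#include <cstdint>

int64_t ipow(int64_t base, int64_t exp) {
  if (exp < 0) {
    return 0; // raising an integer to a negative power is treated as undefined
  }
  int64_t res = 1;
  while (exp > 0) {
    if (exp & 1) {
      res *= base; // multiply in the current square when the bit is set
    }
    base *= base;
    exp >>= 1;
  }
  return res;
}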
@@ -171,11 +171,6 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&) DEFAULT_BINARY(&&)
DEFAULT_BINARY(||) DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T> template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) { Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value; T a = a_.value;
@@ -15,18 +15,6 @@ namespace mlx::core {
namespace { namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T> template <typename T>
struct StridedIterator { struct StridedIterator {
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
@@ -39,7 +27,7 @@ struct StridedIterator {
StridedIterator() = default; StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0) explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: stride_(stride), ptr_(ptr + offset * stride) {} : ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0) explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {} : StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
@@ -142,7 +130,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, nan_aware_less<T>); std::stable_sort(st, ed);
src_it.step(); src_it.step();
} }
} }
@@ -196,15 +184,6 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
@@ -240,7 +219,7 @@ void partition(array& out, int axis, int kth) {
      StridedIterator md(data_ptr, axis_stride, kth);
      StridedIterator ed(data_ptr, axis_stride, axis_size);
-      std::nth_element(st, md, ed, nan_aware_less<T>);
+      std::nth_element(st, md, ed);
    }
  }
@@ -297,15 +276,6 @@ void argpartition(const array& in, array& out, int axis, int kth) {
      std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
        auto v1 = data_ptr[a * in_stride];
        auto v2 = data_ptr[b * in_stride];
-        // Handle NaNs (place them at the end)
-        if (std::is_floating_point<T>::value) {
-          if (std::isnan(v1))
-            return false;
-          if (std::isnan(v2))
-            return true;
-        }
        return v1 < v2 || (v1 == v2 && a < b);
      });
    }
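For context, a standalone sketch of the NaN-last ordering used by the comparator and lambdas shown in this file's hunks; it is a simplified float-only variant (the version above also special-cases complex64_t), and the helper function below is hypothetical.

#include <algorithm>
#include <cmath>
#include <vector>

// NaN-aware "less" predicate: a NaN is never less than anything, and any
// non-NaN is less than a NaN, so a stable sort places NaNs at the end.
template <typename T>
bool nan_aware_less(T a, T b) {
  if (std::isnan(a)) {
    return false;
  }
  if (std::isnan(b)) {
    return true;
  }
  return a < b;
}

// Example: {3.0f, NaN, 1.0f} sorts to {1.0f, 3.0f, NaN}.
void sort_with_nans(std::vector<float>& v) {
  std::stable_sort(v.begin(), v.end(), nan_aware_less<float>);
}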


@@ -8,183 +8,6 @@
namespace mlx::core {
template <typename T, class Enable = void>
struct SVDWork {};
template <typename T>
struct SVDWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <>
struct SVDWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
const int lrwork =
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
int lwork_query = -1;
int work_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension.real();
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <typename T>
void svd_impl(
    const array& a,
@@ -204,8 +27,6 @@ void svd_impl(
  const int N = a.shape(-1);
  const int K = std::min(M, N);
-  using R = typename SVDWork<T>::R;
  size_t num_matrices = a.size() / (M * N);
  // lapack clobbers the input, so we have to make a copy.
@@ -221,7 +42,7 @@ void svd_impl(
  encoder.set_input_array(a);
  auto in_ptr = in.data<T>();
  T* u_ptr;
-  R* s_ptr;
+  T* s_ptr;
  T* vt_ptr;
  if (compute_uv) {
@@ -237,7 +58,7 @@ void svd_impl(
    encoder.set_output_array(s);
    encoder.set_output_array(vt);
-    s_ptr = s.data<R>();
+    s_ptr = s.data<T>();
    u_ptr = u.data<T>();
    vt_ptr = vt.data<T>();
  } else {
@@ -247,26 +68,124 @@ void svd_impl(
    encoder.set_output_array(s);
-    s_ptr = s.data<R>();
+    s_ptr = s.data<T>();
    u_ptr = nullptr;
    vt_ptr = nullptr;
  }
  encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
-    auto jobz = (u_ptr) ? 'A' : 'N';
-    SVDWork<T> svd_work(N, M, K, jobz);
+    // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
+    const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto job_u = (u_ptr) ? "V" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info;
// Compute workspace size.
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
    // Loop over matrices.
    for (int i = 0; i < num_matrices; i++) {
-      svd_work.run(
-          in_ptr + M * N * i,
-          s_ptr + K * i,
-          vt_ptr ? vt_ptr + N * N * i : nullptr,
-          u_ptr ? u_ptr + M * M * i : nullptr);
+      gesvdx<T>(
+          /* jobu = */ job_u,
+          /* jobvt = */ job_vt,
+          /* range = */ range,
+          // M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
    }
  });
  encoder.add_temporary(in);
}
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
@@ -277,12 +196,9 @@ void SVD::eval_cpu(
    case float64:
      svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
      break;
-    case complex64:
-      svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
-      break;
    default:
      throw std::runtime_error(
-          "[SVD::eval_cpu] only supports float32, float64, or complex64.");
+          "[SVD::eval_cpu] only supports float32 or float64.");
  }
}


@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
  auto ndim = a.ndim();
  if (a.flags().contiguous) {
    auto size = a.data_size();
-    constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
+    constexpr int N = simd::max_size<T>;
    while (size >= N) {
-      simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
+      simd::store(dst, Op{}(simd::load<T, N>(src)));
      size -= N;
      src += N;
      dst += N;


@@ -77,8 +77,7 @@ struct Real {
struct Sigmoid {
  template <int N, typename T>
  Simd<T, N> operator()(Simd<T, N> x) {
-    auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
-    return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
+    return 1.0f / (1.0f + simd::exp(-x));
  }
  SINGLE()
};
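A scalar sketch contrasting the two Sigmoid formulations in this hunk: the direct logistic expression and the select-based form that evaluates exp on |x| and then picks y or 1 - y by sign. Both return the same value; the function names below are illustrative, not MLX symbols.

#include <cmath>

// Direct form: 1 / (1 + e^{-x}).
inline float sigmoid_direct(float x) {
  return 1.0f / (1.0f + std::exp(-x));
}

// Select-based form shown above: y = 1 / (1 + e^{|x|}) equals sigmoid(x)
// when x < 0, and 1 - y equals sigmoid(x) when x >= 0.
inline float sigmoid_select(float x) {
  float y = 1.0f / (1.0f + std::exp(std::fabs(x)));
  return x < 0.0f ? y : 1.0f - y;
}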
@@ -108,73 +107,4 @@ struct Square {
  SINGLE()
};
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail


@@ -8,6 +8,7 @@ target_sources(
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
@@ -16,13 +17,8 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
@@ -32,7 +28,6 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
@@ -44,34 +39,24 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
  target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
  target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
else()
  target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -123,21 +108,10 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
    mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
-# Use native CUDA arch by default.
+# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
+# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
-  execute_process(
-    COMMAND __nvcc_device_query
-    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
-    OUTPUT_STRIP_TRAILING_WHITESPACE)
-  set(UPGRADABLE_ARCHITECTURES "90;100;121")
-  if(MLX_CUDA_ARCHITECTURES STREQUAL "")
-    message(
-      FATAL_ERROR
-        "Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
-  elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
-    # Use arch-specific compute capability whenever possible.
-    set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
-  endif()
+  set(MLX_CUDA_ARCHITECTURES "native")
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -149,7 +123,6 @@ FetchContent_Declare(
  URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
-set_target_properties(mlx PROPERTIES CCCL_DIR "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
@@ -175,7 +148,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
  cudnn
  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
-  GIT_TAG v1.16.0
+  GIT_TAG v1.12.1
  GIT_SHALLOW TRUE
  EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
@@ -191,6 +164,7 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
                                   --diag_suppress=997>)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)


@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
-#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"
@@ -31,20 +30,8 @@ SmallSizePool::SmallSizePool() {
  next_free_ = buffer_;
  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
-  int device_count = 0;
-  CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
-  for (int i = 0; i < device_count; ++i) {
-#if CUDART_VERSION >= 13000
-    cudaMemLocation loc;
-    loc.type = cudaMemLocationTypeDevice;
-    loc.id = i;
-#else
-    int loc = i;
-#endif // CUDART_VERSION >= 13000
-    CHECK_CUDA_ERROR(
-        cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
-  }
+  CHECK_CUDA_ERROR(
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
  auto curr = next_free_;
  for (size_t i = 1; i < num_blocks; ++i) {
@@ -68,7 +55,6 @@ CudaBuffer* SmallSizePool::malloc() {
  next_free_ = next_free_->next;
  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
  b->buf.size = small_block_size;
-  b->buf.device = -1;
  return &b->buf;
}
@@ -90,42 +76,16 @@ CudaAllocator::CudaAllocator()
          page_size,
          [](CudaBuffer* buf) { return buf->size; },
          [this](CudaBuffer* buf) { cuda_free(buf); }) {
-  // TODO: Set memory limit for multi-device.
  size_t free, total;
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.9;
+  memory_limit_ = total * 0.8;
  max_pool_size_ = memory_limit_;
-  int device_count = 0;
-  CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
-  int curr;
-  CHECK_CUDA_ERROR(cudaGetDevice(&curr));
-  for (int i = 0; i < device_count; ++i) {
-    CHECK_CUDA_ERROR(cudaSetDevice(i));
-    cudaStream_t s;
-    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
-    free_streams_.push_back(s);
-  }
-  CHECK_CUDA_ERROR(cudaSetDevice(curr));
}
-void copy_to_managed(CudaBuffer& buf) {
-  // TODO maybe make this async on a i/o stream to avoid synchronizing the
-  // device on malloc/and free
-  void* new_data;
-  CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, buf.size));
-  buf.device = -1;
-  CHECK_CUDA_ERROR(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault));
-  CHECK_CUDA_ERROR(cudaFree(buf.data));
-  buf.data = new_data;
-}
-Buffer
-CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
-  if (size == 0) {
-    return Buffer{new CudaBuffer{nullptr, 0, -1}};
-  }
+Buffer CudaAllocator::malloc(size_t size) {
  // Find available buffer from cache.
-  auto orig_size = size;
  std::unique_lock lock(mutex_);
  if (size <= small_block_size) {
    size = 8;
@@ -135,10 +95,6 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
    size = page_size * ((size + page_size - 1) / page_size);
  }
-  if (size <= small_block_size || stream == nullptr) {
-    device = -1;
-  }
  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
  if (!buf) {
    // If we have a lot of memory pressure try to reclaim memory from the cache.
@@ -154,51 +110,30 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
    }
    lock.unlock();
    if (!buf) {
-      cudaError_t err;
-      void* data = nullptr;
-      if (device == -1) {
-        err = cudaMallocManaged(&data, size);
-      } else {
-        err = cudaMallocAsync(&data, size, stream);
-      }
+      buf = new CudaBuffer{nullptr, size};
+      cudaError_t err = cudaMallocManaged(&buf->data, size);
      if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
        throw std::runtime_error(fmt::format(
            "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
      }
-      if (!data) {
-        return Buffer{nullptr};
-      }
-      buf = new CudaBuffer{data, size, device};
    }
    lock.lock();
  }
-  active_memory_ += buf->size;
+  active_memory_ += size;
  peak_memory_ = std::max(active_memory_, peak_memory_);
  // Maintain the cache below the requested limit.
  if (get_cache_memory() > max_pool_size_) {
    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
  }
-  // Copy to managed here if the buffer is not on the right device
-  if (buf->device >= 0 && buf->device != device) {
-    copy_to_managed(*buf);
-  }
  return Buffer{buf};
}
-Buffer CudaAllocator::malloc(size_t size) {
-  return malloc_async(size, -1, nullptr);
-}
void CudaAllocator::free(Buffer buffer) {
  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
  if (!buf) {
    return;
  }
-  if (buf->size == 0) {
-    delete buf;
-    return;
-  }
  std::unique_lock lock(mutex_);
  active_memory_ -= buf->size;
@@ -222,11 +157,7 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
  if (scalar_pool_.in_pool(buf)) {
    scalar_pool_.free(buf);
  } else {
-    if (buf->device >= 0) {
-      CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
-    } else {
-      CHECK_CUDA_ERROR(cudaFree(buf->data));
-    }
+    cudaFree(buf->data);
    delete buf;
  }
}
@@ -277,17 +208,6 @@ CudaAllocator& allocator() {
  return *allocator_;
}
-Buffer malloc_async(size_t size, CommandEncoder& encoder) {
-  auto buffer = allocator().malloc_async(
-      size, encoder.device().cuda_device(), encoder.stream());
-  if (size && !buffer.ptr()) {
-    std::ostringstream msg;
-    msg << "[malloc_async] Unable to allocate " << size << " bytes.";
-    throw std::runtime_error(msg.str());
-  }
-  return buffer;
-}
} // namespace cu
namespace allocator {
@@ -300,11 +220,7 @@ void* Buffer::raw_ptr() {
  if (!ptr_) {
    return nullptr;
  }
-  auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
-  if (cbuf.device != -1) {
-    copy_to_managed(cbuf);
-  }
-  return cbuf.data;
+  return static_cast<cu::CudaBuffer*>(ptr_)->data;
}
} // namespace allocator


@@ -4,24 +4,19 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
-#include "mlx/backend/cuda/cuda_utils.h"
-#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <utility>
namespace mlx::core::cu {
-class CommandEncoder;
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
  void* data;
  size_t size;
-  int device; // -1 for managed
};
class SmallSizePool {
@@ -50,7 +45,6 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator {
 public:
  Buffer malloc(size_t size) override;
-  Buffer malloc_async(size_t size, int device, cudaStream_t stream);
  void free(Buffer buffer) override;
  size_t size(Buffer buffer) const override;
@@ -75,12 +69,9 @@ class CudaAllocator : public allocator::Allocator {
  BufferCache<CudaBuffer> buffer_cache_;
  size_t active_memory_{0};
  size_t peak_memory_{0};
-  std::vector<cudaStream_t> free_streams_;
  SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();
-Buffer malloc_async(size_t size, CommandEncoder& encoder);
} // namespace mlx::core::cu


@@ -6,33 +6,23 @@
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
-#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
+#include <thrust/device_ptr.h>
+#include <thrust/transform.h>
namespace mlx::core {
namespace cu {
-namespace cg = cooperative_groups;
-template <typename T, typename IdxT, int N_WRITES>
-__global__ void arange(T* out, IdxT size, T start, T step) {
-  IdxT index = cg::this_grid().thread_rank();
-  if ((index + 1) * N_WRITES > size) {
-    for (IdxT i = index * N_WRITES; i < size; ++i) {
-      out[i] = start + i * step;
-    }
-  } else {
-    AlignedVector<T, N_WRITES> out_vec;
-#pragma unroll
-    for (int i = 0; i < N_WRITES; ++i) {
-      out_vec[i] = start + (index * N_WRITES + i) * step;
-    }
-    store_vector<N_WRITES>(out, index, out_vec);
-  }
-}
+template <typename T>
+struct Arange {
+  const T start;
+  const T step;
+  __device__ T operator()(uint32_t i) const {
+    return start + i * step;
+  }
+};
} // namespace cu
@@ -41,27 +31,24 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (out.size() == 0) {
    return;
  }
+  out.set_data(allocator::malloc(out.nbytes()));
  auto& encoder = cu::get_command_encoder(stream());
-  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  encoder.set_output_array(out);
+  auto capture = encoder.capture_context();
  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    using OutType = cuda_type_t<CTYPE>;
-    constexpr int N_WRITES = 16 / sizeof(OutType);
-    dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
-      using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-      auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
-      encoder.add_kernel_node(
-          cu::arange<OutType, IdxT, N_WRITES>,
-          num_blocks,
-          block_dims,
-          0,
-          gpu_ptr<OutType>(out),
-          out.data_size(),
-          static_cast<CTYPE>(start_),
-          static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
-    });
+    CTYPE step =
+        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
+    thrust::transform(
+        cu::thrust_policy(encoder.stream()),
+        thrust::counting_iterator<uint32_t>(0),
+        thrust::counting_iterator<uint32_t>(out.data_size()),
+        thrust::device_pointer_cast(out.data<OutType>()),
+        cu::Arange<OutType>{
+            static_cast<OutType>(start_), static_cast<OutType>(step)});
  });
}
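A self-contained sketch of the counting-iterator plus transform pattern used by the Thrust-based Arange path above, written against a plain thrust::device_vector rather than MLX arrays; the functor is renamed ArangeOp to avoid suggesting it is the MLX type.

#include <cstdint>
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

struct ArangeOp {
  float start;
  float step;
  __host__ __device__ float operator()(uint32_t i) const {
    return start + i * step;
  }
};

// Fill out[i] = start + i * step for i in [0, out.size()).
void fill_arange(thrust::device_vector<float>& out, float start, float step) {
  thrust::transform(
      thrust::counting_iterator<uint32_t>(0),
      thrust::counting_iterator<uint32_t>(out.size()),
      out.begin(),
      ArangeOp{start, step});
}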


@@ -140,10 +140,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgReduce::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];
+  out.set_data(allocator::malloc(out.nbytes()));
  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  // Prepare the shapes, strides and axis arguments.
  Shape shape = remove_index(in.shape(), axis_);
@@ -156,6 +154,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  int32_t ndim = shape.size();
  // ArgReduce.
+  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
@@ -173,8 +172,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
          num_blocks,
          block_dim(),
          0,
-          gpu_ptr<T>(in),
-          gpu_ptr<uint32_t>(out),
+          in.data<T>(),
+          out.data<uint32_t>(),
          out.size(),
          const_param(shape),
          const_param(in_strides),


@@ -99,89 +99,39 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
  }
}
-template <
-    typename Op,
-    typename In,
-    typename Out,
-    typename IdxT,
-    int NDIM,
-    int N_READS>
+template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
    const In* a,
    const In* b,
    Out* out,
-    IdxT size_rest,
+    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
-  auto block = cg::this_thread_block();
-  auto grid = cg::this_grid();
-  IdxT index_rest =
-      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
-  if (index_rest >= size_rest) {
-    return;
-  }
-  auto shape_x = shape[NDIM - 1];
-  auto a_stride_x = a_strides[NDIM - 1];
-  auto b_stride_x = b_strides[NDIM - 1];
-  IdxT index_x =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
-      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
-  auto a_vec =
-      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
-  auto b_vec =
-      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
-  AlignedVector<Out, N_READS> out_vec;
-#pragma unroll
-  for (int i = 0; i < N_READS; ++i) {
-    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
-  }
-  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
+        index, shape.data(), a_strides.data(), b_strides.data());
+    out[index] = Op{}(a[a_idx], b[b_idx]);
+  }
}
-template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
+template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
    const In* a,
    const In* b,
    Out* out,
-    IdxT size_rest,
+    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
-  auto block = cg::this_thread_block();
-  auto grid = cg::this_grid();
-  IdxT index_rest =
-      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
-  if (index_rest >= size_rest) {
-    return;
-  }
-  auto shape_x = shape[ndim - 1];
-  auto a_stride_x = a_strides[ndim - 1];
-  auto b_stride_x = b_strides[ndim - 1];
-  IdxT index_x =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  auto [a_idx, b_idx] = elem_to_loc(
-      index_rest * shape_x,
-      shape.data(),
-      a_strides.data(),
-      b_strides.data(),
-      ndim);
-  auto a_vec =
-      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
-  auto b_vec =
-      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
-  AlignedVector<Out, N_READS> out_vec;
-#pragma unroll
-  for (int i = 0; i < N_READS; ++i) {
-    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
-  }
-  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [a_idx, b_idx] = elem_to_loc(
+        index, shape.data(), a_strides.data(), b_strides.data(), ndim);
+    out[index] = Op{}(a[a_idx], b[b_idx]);
+  }
}
template <typename Op, typename In, typename Out>
@@ -259,61 +209,39 @@ void binary_op_gpu_inplace(
    auto& a_strides = strides[0];
    auto& b_strides = strides[1];
    int ndim = shape.size();
-    int work_per_thread = 1;
-    auto dim0 = ndim > 0 ? shape.back() : 1;
-    auto rest = out.size() / dim0;
-    if (dim0 >= 4) {
-      work_per_thread = 4;
-    }
-    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
-    auto block_dims = get_block_dims(dim0, rest, 1);
-    uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
-    uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
    if (ndim <= 3) {
      dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel = cu::binary_g_nd<
-            Op,
-            InType,
-            OutType,
-            IdxT,
-            dims_constant(),
-            1>;
-        if (work_per_thread == 4) {
-          kernel = cu::binary_g_nd<
-              Op,
-              InType,
-              OutType,
-              IdxT,
-              dims_constant(),
-              4>;
-        }
+        auto [num_blocks, block_dims] =
+            get_launch_args(out, large());
        encoder.add_kernel_node(
-            kernel,
-            {num_blocks_x, num_blocks_y},
+            cu::binary_g_nd<
+                Op,
+                InType,
+                OutType,
+                IdxT,
+                dims_constant()>,
+            num_blocks,
            block_dims,
            0,
-            gpu_ptr<InType>(a),
-            gpu_ptr<InType>(b),
-            gpu_ptr<OutType>(out),
-            rest,
+            a.data<InType>(),
+            b.data<InType>(),
+            out.data<OutType>(),
+            out.size(),
            const_param<dims_constant()>(shape),
            const_param<dims_constant()>(a_strides),
            const_param<dims_constant()>(b_strides));
      });
    } else {
-      auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
-      if (work_per_thread == 4) {
-        kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
-      }
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
      encoder.add_kernel_node(
-          kernel,
-          {num_blocks_x, num_blocks_y},
+          cu::binary_g<Op, InType, OutType, IdxT>,
+          num_blocks,
          block_dims,
          0,
-          gpu_ptr<InType>(a),
-          gpu_ptr<InType>(b),
-          gpu_ptr<OutType>(out),
-          rest,
+          a.data<InType>(),
+          b.data<InType>(),
+          out.data<OutType>(),
+          out.size(),
          const_param(shape),
          const_param(a_strides),
          const_param(b_strides),
@@ -339,9 +267,9 @@ void binary_op_gpu_inplace(
          num_blocks,
          block_dims,
          0,
-          gpu_ptr<InType>(a),
-          gpu_ptr<InType>(b),
-          gpu_ptr<OutType>(out),
+          a.data<InType>(),
+          b.data<InType>(),
+          out.data<OutType>(),
          out.data_size());
    });
  }
@@ -365,10 +293,7 @@ void binary_op_gpu(
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
-  auto& encoder = cu::get_command_encoder(s);
-  set_binary_op_output_data(
-      a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
+  set_binary_op_output_data(a, b, out, bopt);
  binary_op_gpu_inplace<Op>(inputs, out, op, s);
}
@@ -379,4 +304,54 @@ void binary_op_gpu(
    binary_op_gpu<cu::func>(inputs, out, name(), s); \
  }
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core
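The general binary_g and binary_g_nd paths above map a flat, row-major output index to an offset in each strided (possibly broadcast) input. A minimal standalone sketch of that index arithmetic follows; elem_to_loc_pair is a hypothetical helper name, and the real MLX utilities have different signatures.

#include <cstdint>
#include <utility>

// Decompose a flat row-major index into per-dimension coordinates and
// accumulate each input's offset using its own strides (a stride of 0
// realizes broadcasting along that dimension).
inline std::pair<int64_t, int64_t> elem_to_loc_pair(
    int64_t index,
    const int32_t* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    int ndim) {
  int64_t a_off = 0;
  int64_t b_off = 0;
  for (int d = ndim - 1; d >= 0; --d) {
    int64_t pos = index % shape[d]; // coordinate along dimension d
    index /= shape[d];
    a_off += pos * a_strides[d];
    b_off += pos * b_strides[d];
  }
  return {a_off, b_off};
}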


@@ -1,21 +0,0 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(ArcTan2)
} // namespace mlx::core


@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Divide)
} // namespace mlx::core


@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Greater)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(GreaterEqual)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Less)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LessEqual)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogAddExp)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalAnd)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalOr)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Maximum)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Minimum)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Multiply)
} // namespace mlx::core


@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(NotEqual)
} // namespace mlx::core
