Compare commits

..

1 Commits

Author SHA1 Message Date
Angelos Katharopoulos
11f73d6e89 Double buffer keys for vector sdpa 2025-04-22 00:19:11 -07:00
501 changed files with 7577 additions and 38051 deletions

View File

@@ -7,9 +7,15 @@ parameters:
nightly_build: nightly_build:
type: boolean type: boolean
default: false default: false
weekly_build:
type: boolean
default: false
test_release: test_release:
type: boolean type: boolean
default: false default: false
linux_release:
type: boolean
default: false
jobs: jobs:
build_documentation: build_documentation:
@@ -18,14 +24,13 @@ jobs:
type: boolean type: boolean
default: false default: false
macos: macos:
xcode: "26.0.0" xcode: "16.2.0"
resource_class: m4pro.medium resource_class: m2pro.medium
steps: steps:
- checkout - checkout
- run: - run:
name: Install name: Install
command: | command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9 brew install python@3.9
brew install doxygen brew install doxygen
python3.9 -m venv env python3.9 -m venv env
@@ -33,7 +38,7 @@ jobs:
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install -r docs/requirements.txt pip install -r docs/requirements.txt
pip install . -v CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when: - when:
condition: condition:
not: << parameters.upload-docs >> not: << parameters.upload-docs >>
@@ -65,9 +70,9 @@ jobs:
git push -f origin gh-pages git push -f origin gh-pages
linux_build_and_test: linux_build_and_test:
machine: docker:
image: ubuntu-2204:current - image: cimg/python:3.9
resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
@@ -79,36 +84,36 @@ jobs:
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
export DEBIAN_FRONTEND=noninteractive pip install --upgrade cmake
export NEEDRESTART_MODE=a pip install nanobind==2.4.0
pip install numpy
sudo apt-get update sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
uv venv CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
uv pip install cmake CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \ python3 setup.py build_ext --inplace
uv pip install -e ".[dev]" -v CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
uv pip install typing_extensions echo "stubs"
uv run --no-project setup.py generate_stubs pip install typing_extensions
python setup.py generate_stubs
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
source .venv/bin/activate python3 -m unittest discover python/tests -v
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2) mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run: - run:
name: Build CPP only name: Build CPP only
command: | command: |
source .venv/bin/activate
mkdir -p build && cd build mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc` make -j `nproc`
@@ -120,7 +125,7 @@ jobs:
parameters: parameters:
xcode_version: xcode_version:
type: string type: string
default: "26.0.0" default: "16.2.0"
macosx_deployment_target: macosx_deployment_target:
type: string type: string
default: "" default: ""
@@ -128,56 +133,57 @@ jobs:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
environment: environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >> MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium resource_class: m2pro.medium
steps: steps:
- checkout - checkout
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
xcodebuild -downloadComponent MetalToolchain brew install python@3.9
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \ brew install openmpi
brew install openmpi uv python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
uv venv --python 3.9 source env/bin/activate
uv pip install \ DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
nanobind==2.4.0 \ CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
cmake \ pip install -e . -v
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
uv pip install typing_extensions source env/bin/activate
uv run --no-project setup.py generate_stubs pip install typing_extensions
python setup.py generate_stubs
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
source .venv/bin/activate source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2) mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run: - run:
name: Build example extension name: Build example extension
command: | command: |
source .venv/bin/activate source env/bin/activate
cd examples/extensions cd examples/extensions
uv pip install -r requirements.txt pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace python setup.py build_ext -j8
uv run --no-project python test.py
- store_test_results: - store_test_results:
path: test-results path: test-results
- run: - run:
name: Build CPP only name: Build CPP only
command: | command: |
source .venv/bin/activate source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu` mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run: - run:
name: Run CPP tests name: Run CPP tests
@@ -186,7 +192,7 @@ jobs:
- run: - run:
name: Build small binary name: Build small binary
command: | command: |
source .venv/bin/activate source env/bin/activate
cd build/ cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \ cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \ -DBUILD_SHARED_LIBS=ON \
@@ -198,76 +204,13 @@ jobs:
- run: - run:
name: Run Python tests with JIT name: Run Python tests with JIT
command: | command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ source env/bin/activate
uv pip install -e . -v CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \ LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \ METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \ python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release: build_release:
parameters: parameters:
@@ -276,7 +219,7 @@ jobs:
default: "3.9" default: "3.9"
xcode_version: xcode_version:
type: string type: string
default: "26.0.0" default: "16.2.0"
build_env: build_env:
type: string type: string
default: "" default: ""
@@ -285,7 +228,7 @@ jobs:
default: "" default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: m4pro.medium resource_class: m2pro.medium
environment: environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >> MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps: steps:
@@ -293,15 +236,11 @@ jobs:
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
xcodebuild -downloadComponent MetalToolchain brew install python@<< parameters.python_version >>
mkdir -p ~/miniconda3 brew install openmpi
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh python<< parameters.python_version >> -m venv env
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 source env/bin/activate
rm ~/miniconda3/miniconda.sh pip install --upgrade pip
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.4.0 pip install nanobind==2.4.0
pip install --upgrade setuptools pip install --upgrade setuptools
@@ -311,38 +250,30 @@ jobs:
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
conda activate env source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \ env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v pip install . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
conda activate env source env/bin/activate
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Build Python package name: Build Python package
command: | command: |
conda activate env source env/bin/activate
python setup.py clean --all << parameters.build_env >> \
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
- when: python -m build -w
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when: - when:
condition: << parameters.build_env >> condition: << parameters.build_env >>
steps: steps:
- run: - run:
name: Upload package name: Upload package
command: | command: |
conda activate env source env/bin/activate
twine upload dist/* twine upload dist/*
- store_artifacts: - store_artifacts:
path: dist/ path: dist/
@@ -352,100 +283,52 @@ jobs:
python_version: python_version:
type: string type: string
default: "3.9" default: "3.9"
build_env: extra_env:
type: string type: string
default: "" default: "DEV_RELEASE=1"
machine: docker:
image: ubuntu-2204:current - image: ubuntu:20.04
resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
name: Build wheel name: Build wheel
command: | command: |
PYTHON=python<< parameters.python_version >> PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive apt-get update
export NEEDRESTART_MODE=a apt-get upgrade -y
sudo apt-get update DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
TZ=Etc/UTC sudo apt-get -y install tzdata apt-get install -y apt-utils
sudo add-apt-repository -y ppa:deadsnakes/ppa apt-get install -y software-properties-common
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env $PYTHON -m venv env
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
<< parameters.build_env >> pip install ".[dev]" -v << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
python setup.py clean --all << parameters.extra_env >> \
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
bash python/scripts/repair_linux.sh python -m build --wheel
- when: auditwheel show dist/*
condition: auditwheel repair dist/* --plat manylinux_2_31_x86_64
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run: - run:
name: Build wheel name: Upload package
command: | command: |
export DEBIAN_FRONTEND=noninteractive source env/bin/activate
export NEEDRESTART_MODE=a twine upload wheelhouse/*
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
@@ -457,23 +340,21 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$" pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >> value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >> - not: << pipeline.parameters.test_release >>
jobs: jobs:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
parameters: parameters:
macosx_deployment_target: ["13.5", "15.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test - linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation - build_documentation
build_pypi_release: build_pypi_release:
when: when:
and: and:
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >> - not: << pipeline.parameters.test_release >>
jobs: jobs:
- build_release: - build_release:
@@ -487,7 +368,68 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"] build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"] xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation: - build_documentation:
filters: filters:
tags: tags:
@@ -495,25 +437,6 @@ workflows:
branches: branches:
ignore: /.*/ ignore: /.*/
upload-docs: true upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb: prb:
when: when:
@@ -529,14 +452,9 @@ workflows:
requires: [ hold ] requires: [ hold ]
matrix: matrix:
parameters: parameters:
macosx_deployment_target: ["13.5", "15.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test: - linux_build_and_test:
requires: [ hold ] requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build: nightly_build:
when: when:
and: and:
@@ -548,18 +466,58 @@ workflows:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"] xcode_version: ["16.2.0", "15.0.0"]
- build_linux_release: exclude:
matrix: - macosx_deployment_target: "13.5"
parameters: xcode_version: "16.2.0"
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: "3.9"
- build_cuda_release - macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
build_dev_release: python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
when: when:
and: and:
- equal: [ main, << pipeline.git.branch >> ] - equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >> - << pipeline.parameters.weekly_build >>
jobs: jobs:
- build_release: - build_release:
matrix: matrix:
@@ -567,13 +525,76 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"] build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"] xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release: - build_linux_release:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["DEV_RELEASE=1"] extra_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]

1
.gitignore vendored
View File

@@ -36,7 +36,6 @@ share/python-wheels/
.installed.cfg .installed.cfg
*.egg *.egg
MANIFEST MANIFEST
uv.lock
# vim # vim
*.swp *.swp

View File

@@ -19,17 +19,11 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions. - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation - Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
<a href="https://github.com/ml-explore/mlx/graphs/contributors"> <a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" /> <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a> </a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software # Third-Party Software
MLX leverages several third-party software, listed here together with MLX leverages several third-party software, listed here together with

View File

@@ -26,7 +26,6 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER) set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration ----------------------------- # ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON) option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -35,16 +34,13 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON) option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON) option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF) option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF) option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests ------------------------- # --------------------- Processor tests -------------------------
message( message(
@@ -67,17 +63,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.") message(WARNING "Building for x86_64 arch is not officially supported.")
endif() endif()
endif() endif()
else() else()
set(MLX_BUILD_METAL OFF) set(MLX_BUILD_METAL OFF)
endif() message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
endif() endif()
# ----------------------------- Lib ----------------------------- # ----------------------------- Lib -----------------------------
@@ -88,21 +77,18 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx) add_library(mlx)
if(MLX_BUILD_CUDA) if(MLX_BUILD_METAL)
enable_language(CUDA) set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif() endif()
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL AND NOT METAL_LIB)
find_library(METAL_LIB Metal) message(STATUS "Metal not found. Unable to build GPU")
find_library(FOUNDATION_LIB Foundation) set(MLX_BUILD_METAL OFF)
find_library(QUARTZ_LIB QuartzCore) set(MLX_METAL_DEBUG OFF)
if(METAL_LIB) elseif(MLX_BUILD_METAL)
message(STATUS "Metal found ${METAL_LIB}") message(STATUS "Building METAL sources")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG) add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,8 +97,7 @@ if(MLX_BUILD_METAL)
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process( execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0) if(${MACOS_SDK_VERSION} LESS 14.0)
message( message(
@@ -141,12 +126,6 @@ if(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif() endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32) if(WIN32)
if(MSVC) if(MSVC)
# GGUF does not build with MSVC. # GGUF does not build with MSVC.
@@ -174,7 +153,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON) set(MLX_BUILD_ACCELERATE ON)
else() else()
message(STATUS "Accelerate not found, using default backend.") message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF) set(MLX_BUILD_ACCELERATE OFF)
endif() endif()
@@ -247,19 +226,12 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}> mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>) $<INSTALL_INTERFACE:include>)
# Do not add mlx_EXPORTS define for shared library. FetchContent_Declare(
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "") fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
if(USE_SYSTEM_FMT) GIT_TAG 10.2.1
find_package(fmt REQUIRED) EXCLUDE_FROM_ALL)
else() FetchContent_MakeAvailable(fmt)
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>) target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS) if(MLX_BUILD_PYTHON_BINDINGS)

View File

@@ -1,6 +1,4 @@
include CMakeLists.txt include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ * recursive-include mlx/ *
include cmake/*
include python/src/* include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561 include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -11,31 +11,31 @@ brought to you by Apple machine learning research.
Some key features of MLX include: Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models. more complex models.
- **Composable function transformations**: MLX supports composable function - **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization, transformations for automatic differentiation, automatic vectorization,
and computation graph optimization. and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only - **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed. materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed - **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive. slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices - **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU). (currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks - **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory. is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported Operations on MLX arrays can be performed on any of the supported
device types without transferring data. device types without transferring data.
MLX is designed by machine learning researchers for machine learning MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient researchers. The framework is intended to be user-friendly, but still efficient
@@ -68,23 +68,18 @@ in the documentation.
## Installation ## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
macOS, run:
```bash **With `pip`**:
```
pip install mlx pip install mlx
``` ```
To install the CUDA backend on Linux, run: **With `conda`**:
```bash
pip install mlx[cuda]
``` ```
conda install -c conda-forge mlx
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
``` ```
Checkout the Checkout the
@@ -110,7 +105,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following MLX useful in your research and wish to cite it, please use the following
BibTex entry: BibTex entry:
```text ```
@software{mlx2023, @software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert}, author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon}, title = {{MLX}: Efficient and flexible machine learning on Apple silicon},

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>

View File

@@ -192,22 +192,6 @@ void time_reductions() {
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); }; auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1); TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
} }
void time_gather_scatter() { void time_gather_scatter() {

View File

@@ -142,7 +142,9 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1) t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b) c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype) c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
atol = 1e-5 if np_dtype == np.float32 else 1e-4 atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -161,7 +163,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks") parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16", "complex64") dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn") transposes = ("nn", "nt", "tn")
shapes = ( shapes = (
(16, 234, 768, 3072), (16, 234, 768, 3072),
@@ -185,7 +187,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0 diff = gflops_mx / gflops_pt - 1.0
print( print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%" f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
) )
if gflops_pt >= 2.0 * gflops_mx: if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^") print("ATTENTION ^^^^^^^")

View File

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True): for transpose in (False, True):
for dtype in ("float32", "float16", "complex64"): for dtype in ("float32", "float16"):
fig, axs = plt.subplots( fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained" len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
) )
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}") fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig( fig.savefig(
os.path.join( os.path.join(
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf" results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
) )
) )
plt.close(fig) plt.close(fig)

View File

@@ -5,7 +5,6 @@ import os
import time import time
import torch import torch
import torch.cuda
import torch.mps import torch.mps
@@ -45,10 +44,8 @@ def bench(f, *args):
def sync_if_needed(x): def sync_if_needed(x):
if x.device == torch.device("mps"): if x.device != torch.device("cpu"):
torch.mps.synchronize() torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()
@torch.no_grad() @torch.no_grad()
@@ -102,14 +99,6 @@ def reduction(op, axis, x):
sync_if_needed(x) sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad() @torch.no_grad()
def softmax(axis, x): def softmax(axis, x):
ys = [] ys = []
@@ -351,11 +340,7 @@ if __name__ == "__main__":
args.axis.pop(0) args.axis.pop(0)
torch.set_num_threads(1) torch.set_num_threads(1)
device = "mps" device = "cpu" if args.cpu else "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
types = args.dtype types = args.dtype
if not types: if not types:
@@ -475,8 +460,5 @@ if __name__ == "__main__":
elif args.benchmark == "selu": elif args.benchmark == "selu":
print(bench(selu, x)) print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else: else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.") raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -1,107 +0,0 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,7 +1,5 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
from functools import partial
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from time_utils import time_fn from time_utils import time_fn
@@ -20,63 +18,51 @@ def layer_norm(x, w, b, eps):
return y return y
def time_layer_norm(N, dt): def time_layer_norm():
L = 1024
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum() f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum() f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2)) g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2)) g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, L, N)).astype(dt) x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(N,)).astype(dt) w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(N,)).astype(dt) b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, L, N)).astype(dt) y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y) mx.eval(x, w, b, y)
def layer_norm_loop(f, x, w, b): def layer_norm_loop(g, x, w, b):
for _ in range(32):
x = f(x, w, b)
return x
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
def layer_norm_grad_loop(g, x, w, b):
gx, gw, gb = x, w, b gx, gw, gb = x, w, b
for _ in range(32): for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y) gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb return gx, gw, gb
time_fn(layer_norm_grad_loop, g1, x, w, b) time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_grad_loop, g2, x, w, b) time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b) time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b) time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum() f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum() f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,)) g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,)) g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, L, N)).astype(dt) x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(N,)).astype(dt) w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(N,)).astype(dt) b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, L, N)).astype(dt) y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y) mx.eval(x, w, b, y)
def layer_norm_grad_x_loop(g, x): def layer_norm_loop(g, x):
gx = x gx = x
for _ in range(32): for _ in range(32):
gx = g(gx, y) gx = g(gx, y)
return gx return gx
time_fn(layer_norm_grad_x_loop, g1, x) time_fn(layer_norm_loop, g1, x)
time_fn(layer_norm_grad_x_loop, g2, x) time_fn(layer_norm_loop, g2, x)
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x) time_fn(layer_norm_loop, mx.compile(g1), x)
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x) time_fn(layer_norm_loop, mx.compile(g2), x)
if __name__ == "__main__": if __name__ == "__main__":
for dt in [mx.float32, mx.float16, mx.bfloat16]: time_layer_norm()
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
print(dt, n)
time_layer_norm(n, dt)

View File

@@ -4,7 +4,7 @@ import math
import mlx.core as mx import mlx.core as mx
from time_utils import time_fn from time_utils import time_fn
L = 16384 L = 1024
H = 32 H = 32
H_k = H // 4 H_k = H // 4
D = 128 D = 128

View File

@@ -51,20 +51,6 @@ def time_maximum():
time_fn(mx.maximum, a, b) time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative(): def time_negative():
a = mx.random.uniform(shape=(10000, 1000)) a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a) mx.eval(a)
@@ -122,8 +108,6 @@ if __name__ == "__main__":
time_add() time_add()
time_matmul() time_matmul()
time_min()
time_max()
time_maximum() time_maximum()
time_exp() time_exp()
time_negative() time_negative()

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -11,14 +11,13 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) DEBUG: Boolean, if true, enables debug compile options # files (like headers)
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
# #
# clang format on # clang format on
macro(mlx_build_metallib) macro(mlx_build_metallib)
# Parse args # Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG) set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -27,10 +26,6 @@ macro(mlx_build_metallib)
# Collect compile options # Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Prepare metallib build command # Prepare metallib build command
add_custom_command( add_custom_command(

View File

@@ -1,5 +1,4 @@
sphinx sphinx
breathe breathe
sphinx-book-theme sphinx-book-theme
sphinx-copybutton
mlx mlx

View File

@@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "MLX" project = "MLX"
copyright = "2023, Apple" copyright = "2023, MLX Contributors"
author = "MLX Contributors" author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3]) version = ".".join(mx.__version__.split(".")[:3])
release = version release = version
@@ -18,7 +18,6 @@ release = version
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
extensions = [ extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc", "sphinx.ext.autodoc",
"sphinx.ext.autosummary", "sphinx.ext.autosummary",
"sphinx.ext.intersphinx", "sphinx.ext.intersphinx",

View File

@@ -8,26 +8,23 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example Simple Example
-------------- --------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise: Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python .. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
def exp_elementwise(a: mx.array): def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
outputs = kernel( outputs = kernel(
inputs=[a], inputs=[a],
template=[("T", mx.float32)], template=[("T", mx.float32)],
@@ -42,13 +39,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
b = exp_elementwise(a) b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a)) assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note:: .. note::
Only pass the body of the Metal kernel in ``source``. The function We are only required to pass the body of the Metal kernel in ``source``.
signature is generated automatically.
The full function signature will be generated using: The full function signature will be generated using:
@@ -86,52 +78,44 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>; template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
generated code for debugging purposes.
Using Shape/Strides Using Shape/Strides
------------------- -------------------
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which ``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
is ``True`` by default. This will copy the array inputs if needed This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
before the kernel is launched to ensure that the memory layout is row Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
contiguous. Generally this makes writing the kernel easier, since we don't when indexing.
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are input array ``a`` if any are present in ``source``.
present in ``source``. We can then use MLX's built in indexing utils to fetch We can then use MLX's built in indexing utils to fetch the right elements for each thread.
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python .. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source,
ensure_row_contiguous=False,
)
def exp_elementwise(a: mx.array): def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel( outputs = kernel(
inputs=[a], inputs=[a],
template=[("T", mx.float32)], template=[("T", mx.float32)],
@@ -139,6 +123,7 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes=[a.shape], output_shapes=[a.shape],
output_dtypes=[a.dtype], output_dtypes=[a.dtype],
ensure_row_contiguous=False,
) )
return outputs[0] return outputs[0]
@@ -157,139 +142,137 @@ We'll start with the following MLX implementation using standard ops:
.. code-block:: python .. code-block:: python
def grid_sample_ref(x, grid): def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2 ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2 iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32) ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32) iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1 ix_ne = ix_nw + 1
iy_ne = iy_nw iy_ne = iy_nw
ix_sw = ix_nw ix_sw = ix_nw
iy_sw = iy_nw + 1 iy_sw = iy_nw + 1
ix_se = ix_nw + 1 ix_se = ix_nw + 1
iy_se = iy_nw + 1 iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy) nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy) ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne) sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw) se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None] I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None] I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None] I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None] I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output return output
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel` Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes. to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel: First we'll implement the forward pass as a fused kernel:
.. code-block:: python .. code-block:: python
source = """ @mx.custom_function
uint elem = thread_position_in_grid.x; def grid_sample(x, grid):
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C; assert x.ndim == 4, "`x` must be 4D."
int h_stride = W * w_stride; assert grid.ndim == 4, "`grid` must be 4D."
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2; B, _, _, C = x.shape
float ix = ((grid[grid_idx] + 1) * W - 1) / 2; _, gN, gM, D = grid.shape
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; out_shape = (B, gN, gM, C)
int ix_nw = floor(ix); assert D == 2, "Last dim of `grid` must be size 2."
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1; source = """
int iy_ne = iy_nw; uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int ix_sw = ix_nw; int w_stride = C;
int iy_sw = iy_nw + 1; int h_stride = W * w_stride;
int b_stride = H * h_stride;
int ix_se = ix_nw + 1; uint grid_idx = elem / C * 2;
int iy_se = iy_nw + 1; float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
T nw = (ix_se - ix) * (iy_se - iy); int ix_nw = floor(ix);
T ne = (ix - ix_sw) * (iy_sw - iy); int iy_nw = floor(iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride; int ix_ne = ix_nw + 1;
int channel_idx = elem % C; int iy_ne = iy_nw;
int base_idx = batch_idx + channel_idx;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; int ix_sw = ix_nw;
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; int iy_sw = iy_nw + 1;
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; int ix_se = ix_nw + 1;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; int iy_se = iy_nw + 1;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; T nw = (ix_se - ix) * (iy_se - iy);
""" T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
kernel = mx.fast.metal_kernel( int batch_idx = elem / C / gH / gW * b_stride;
name="grid_sample", int channel_idx = elem % C;
input_names=["x", "grid"], int base_idx = batch_idx + channel_idx;
output_names=["out"],
source=source,
)
@mx.custom_function T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
def grid_sample(x, grid): T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
assert x.ndim == 4, "`x` must be 4D." I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
assert grid.ndim == 4, "`grid` must be 4D." I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
B, _, _, C = x.shape out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
_, gN, gM, D = grid.shape """
out_shape = (B, gN, gM, C) kernel = mx.fast.metal_kernel(
name="grid_sample",
assert D == 2, "Last dim of `grid` must be size 2." input_names=["x", "grid"],
output_names=["out"],
outputs = kernel( source=source,
inputs=[x, grid], )
template=[("T", x.dtype)], outputs = kernel(
output_shapes=[out_shape], inputs=[x, grid],
output_dtypes=[x.dtype], template=[("T", x.dtype)],
grid=(np.prod(out_shape), 1, 1), output_shapes=[out_shape],
threadgroup=(256, 1, 1), output_dtypes=[x.dtype],
) grid=(np.prod(out_shape), 1, 1),
return outputs[0] threadgroup=(256, 1, 1),
)
return outputs[0]
For a reasonably sized input such as: For a reasonably sized input such as:
.. code-block:: python .. code-block:: python
x.shape = (8, 1024, 1024, 64) x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2) grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement: On an M1 Max, we see a big performance improvement:
@@ -298,11 +281,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP Grid Sample VJP
--------------- ---------------
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
define its custom vjp transform so MLX can differentiate it. its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features: requires a few extra ``mx.fast.metal_kernel`` features:
* ``init_value=0`` * ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -316,129 +299,128 @@ We can then implement the backwards pass as follows:
.. code-block:: python .. code-block:: python
source = """ @grid_sample.vjp
uint elem = thread_position_in_grid.x; def grid_sample_vjp(primals, cotangent, _):
int H = x_shape[1]; x, grid = primals
int W = x_shape[2]; B, _, _, C = x.shape
int C = x_shape[3]; _, gN, gM, D = grid.shape
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1]; assert D == 2, "Last dim of `grid` must be size 2."
int gW = grid_shape[2];
int w_stride = C; source = """
int h_stride = W * w_stride; uint elem = thread_position_in_grid.x;
int b_stride = H * h_stride; int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
uint grid_idx = elem / C_padded * 2; int gH = grid_shape[1];
float ix = ((grid[grid_idx] + 1) * W - 1) / 2; int gW = grid_shape[2];
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix); int w_stride = C;
int iy_nw = floor(iy); int h_stride = W * w_stride;
int b_stride = H * h_stride;
int ix_ne = ix_nw + 1; uint grid_idx = elem / C_padded * 2;
int iy_ne = iy_nw; float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_sw = ix_nw; int ix_nw = floor(ix);
int iy_sw = iy_nw + 1; int iy_nw = floor(iy);
int ix_se = ix_nw + 1; int ix_ne = ix_nw + 1;
int iy_se = iy_nw + 1; int iy_ne = iy_nw;
T nw = (ix_se - ix) * (iy_se - iy); int ix_sw = ix_nw;
T ne = (ix - ix_sw) * (iy_sw - iy); int iy_sw = iy_nw + 1;
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C_padded / gH / gW * b_stride; int ix_se = ix_nw + 1;
int channel_idx = elem % C_padded; int iy_se = iy_nw + 1;
int base_idx = batch_idx + channel_idx;
T gix = T(0); T nw = (ix_se - ix) * (iy_se - iy);
T giy = T(0); T ne = (ix - ix_sw) * (iy_sw - iy);
if (channel_idx < C) { T sw = (ix_ne - ix) * (iy - iy_ne);
int cot_index = elem / C_padded * C + channel_idx; T se = (ix - ix_nw) * (iy - iy_nw);
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset]; int batch_idx = elem / C_padded / gH / gW * b_stride;
gix -= I_nw * (iy_se - iy) * cot; int channel_idx = elem % C_padded;
giy -= I_nw * (ix_se - ix) * cot; int base_idx = batch_idx + channel_idx;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_ne = x[offset]; T gix = T(0);
gix += I_ne * (iy_sw - iy) * cot; T giy = T(0);
giy -= I_ne * (ix - ix_sw) * cot; if (channel_idx < C) {
} int cot_index = elem / C_padded * C + channel_idx;
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { T cot = cotangent[cot_index];
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_sw = x[offset]; T I_nw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot; gix -= I_nw * (iy_se - iy) * cot;
giy += I_sw * (ix_ne - ix) * cot; giy -= I_nw * (ix_se - ix) * cot;
} }
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride; int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_se = x[offset]; T I_ne = x[offset];
gix += I_se * (iy - iy_nw) * cot; gix += I_ne * (iy_sw - iy) * cot;
giy += I_se * (ix - ix_nw) * cot; giy -= I_ne * (ix - ix_sw) * cot;
} }
} if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T gix_mult = W / 2; T I_sw = x[offset];
T giy_mult = H / 2; gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
// Reduce across each simdgroup first. T I_se = x[offset];
// This is much faster than relying purely on atomics. gix += I_se * (iy - iy_nw) * cot;
gix = simd_sum(gix); giy += I_se * (ix - ix_nw) * cot;
giy = simd_sum(giy); }
}
if (thread_index_in_simdgroup == 0) { T gix_mult = W / 2;
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); T giy_mult = H / 2;
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
@grid_sample.vjp // Reduce across each simdgroup first.
def grid_sample_vjp(primals, cotangent, _): // This is much faster than relying purely on atomics.
x, grid = primals gix = simd_sum(gix);
B, _, _, C = x.shape giy = simd_sum(giy);
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2." if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
# pad the output channels to simd group size atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
# so that our `simd_sum`s don't overlap. }
simdgroup_size = 32 """
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size kernel = mx.fast.metal_kernel(
grid_size = B * gN * gM * C_padded name="grid_sample_grad",
outputs = kernel( input_names=["x", "grid", "cotangent"],
inputs=[x, grid, cotangent], output_names=["x_grad", "grid_grad"],
template=[("T", x.dtype)], source=source,
output_shapes=[x.shape, grid.shape], atomic_outputs=True,
output_dtypes=[x.dtype, x.dtype], )
grid=(grid_size, 1, 1), # pad the output channels to simd group size
threadgroup=(256, 1, 1), # so that our `simd_sum`s don't overlap.
init_value=0, simdgroup_size = 32
) C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
return outputs[0], outputs[1] grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp: There's an even larger speed up for the vjp:

View File

@@ -138,13 +138,13 @@ more concrete:
* representing the vectorized computation and the axis which * representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension. * corresponds to the output vectorized dimension.
*/ */
std::pair<std::vector<array>, std::vector<int>> vmap( virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) override; const std::vector<int>& axes) override;
/** The name of primitive. */ /** Print the primitive. */
const char* name() const override { void print(std::ostream& os) override {
return "Axpby"; os << "Axpby";
} }
/** Equivalence check **/ /** Equivalence check **/
@@ -394,14 +394,14 @@ below.
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel // Resolve name of kernel
std::stream kname; std::ostringstream kname;
kname = "axpby_general_" + type_to_name(out); kname << "axpby_" << "general_" << type_to_name(out);
// Load the metal library // Make sure the metal library is available
auto lib = d.get_library("mlx_ext", current_binary_dir()); d.register_library("mlx_ext");
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib); auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -70,7 +70,6 @@ are the CPU and GPU.
python/fft python/fft
python/linalg python/linalg
python/metal python/metal
python/cuda
python/memory_management python/memory_management
python/nn python/nn
python/optimizers python/optimizers

View File

@@ -13,7 +13,7 @@ silicon computer is
pip install mlx pip install mlx
To install from PyPI your system must meet the following requirements: To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon) - Using an M series chip (Apple silicon)
- Using a native Python >= 3.9 - Using a native Python >= 3.9
@@ -23,39 +23,12 @@ To install from PyPI your system must meet the following requirements:
MLX is only available on devices running macOS >= 13.5 MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma) It is highly recommended to use macOS 14 (Sonoma)
CUDA
^^^^
MLX has a CUDA backend which you can install with: MLX is also available on conda-forge. To install MLX with conda do:
.. code-block:: shell .. code-block:: shell
pip install mlx[cuda] conda install conda-forge::mlx
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
Troubleshooting Troubleshooting
@@ -92,8 +65,6 @@ Build Requirements
Python API Python API
^^^^^^^^^^ ^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_: `its GitHub repo <https://github.com/ml-explore/mlx>`_:
@@ -105,20 +76,20 @@ Then simply build and install MLX using pip:
.. code-block:: shell .. code-block:: shell
pip install . CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
For developing, install the package with development dependencies, and use an For developing, install the package with development dependencies, and use an
editable install: editable install:
.. code-block:: shell .. code-block:: shell
pip install -e ".[dev]" CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with: Once the development dependencies are installed, you can build faster with:
.. code-block:: shell .. code-block:: shell
python setup.py build_ext --inplace CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
Run the tests with: Run the tests with:
@@ -136,8 +107,6 @@ IDE:
C++ API C++ API
^^^^^^^ ^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source. Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start Similarly to the python library, to build and install the MLX C++ library start
@@ -216,7 +185,6 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version xcrun -sdk macosx --show-sdk-version
Binary Size Minimization Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -245,50 +213,6 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots. Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@@ -19,8 +19,6 @@ Array
array.ndim array.ndim
array.shape array.shape
array.size array.size
array.real
array.imag
array.abs array.abs
array.all array.all
array.any array.any

View File

@@ -1,9 +0,0 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,4 +13,3 @@ Fast
rope rope
scaled_dot_product_attention scaled_dot_product_attention
metal_kernel metal_kernel
cuda_kernel

View File

@@ -20,5 +20,3 @@ FFT
irfft2 irfft2
rfftn rfftn
irfftn irfftn
fftshift
ifftshift

View File

@@ -16,8 +16,6 @@ Linear Algebra
cross cross
qr qr
svd svd
eigvals
eig
eigvalsh eigvalsh
eigh eigh
lu lu

View File

@@ -27,7 +27,6 @@ simple functions.
mish mish
prelu prelu
relu relu
relu2
relu6 relu6
selu selu
sigmoid sigmoid

View File

@@ -50,7 +50,6 @@ Layers
QuantizedLinear QuantizedLinear
RMSNorm RMSNorm
ReLU ReLU
ReLU2
ReLU6 ReLU6
RNN RNN
RoPE RoPE

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads) optimizer.update(model, grads)
# Save the state # Save the state
state = tree_flatten(optimizer.state, destination={}) state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", state) mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint, # Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state # recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2) optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors")) state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -19,4 +19,3 @@ Common Optimizers
Adamax Adamax
Lion Lion
MultiOptimizer MultiOptimizer
Muon

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python .. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096)) x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x) timeit(nn.gelu, x)
timeit(mx.compile(gelu), x) timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
def fun(x, y): def fun(x, y):
z = x + y z = x + y
state.append(z) state.append(z)
return mx.exp(z) return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0)) fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)] # Prints [array(3, dtype=float32)]

View File

@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y): def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y) loss, grads = loss_grad_fn(model, x, y)
grads = mx.nn.average_gradients(grads) # <---- This line was added grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads) optimizer.update(model, grads)
return loss return loss

View File

@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
model.update(tree_unflatten(list(params.items()))) model.update(tree_unflatten(list(params.items())))
return model(x) return model(x)
params = tree_flatten(model.parameters(), destination={}) params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python .. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True) mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn") imported_abs = mx.import_function("fun.mlxfn")
# Ok # Ok
out, = imported_abs(mx.array([-1.0])) out, = imported_abs(mx.array(-1.0))
# Also ok # Also ok
out, = imported_abs(mx.array([-1.0, -2.0])) out, = imported_abs(mx.array([-1.0, -2.0]))

View File

@@ -107,28 +107,6 @@ same array:
>>> a >>> a
array([1, 2, 0], dtype=int32) array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
Transformations of functions which use in-place updates are allowed and work as Transformations of functions which use in-place updates are allowed and work as
expected. For example: expected. For example:

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2025 Apple Inc. // Copyright © 2023-2025 Apple Inc.
#include <dlfcn.h>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
@@ -17,19 +16,6 @@
namespace my_ext { namespace my_ext {
// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
static std::string binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path().string();
}();
return binary_dir;
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Operation Implementation // Operation Implementation
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -181,15 +167,16 @@ void Axpby::eval_gpu(
} }
// Resolve name of kernel (corresponds to axpby.metal) // Resolve name of kernel (corresponds to axpby.metal)
std::string kname = "axpby_"; std::ostringstream kname;
kname += (contiguous_kernel ? "contiguous_" : "general_"); kname << "axpby_";
kname += type_to_name(out); kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Load the metal library // Make sure the metal library is available
auto lib = d.get_library("mlx_ext", current_binary_dir()); d.register_library("mlx_ext");
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib); auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
const std::vector<mx::array>& inputs, const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override; const std::vector<int>& axes) override;
/** The name of primitive. */ /** Print the primitive. */
const char* name() const override { void print(std::ostream& os) override {
return "Axpby"; os << "Axpby";
} }
/** Equivalence check **/ /** Equivalence check **/

View File

@@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.25 cmake>=3.25
mlx>=0.21.0 mlx>=0.21.0
nanobind==2.4.0 nanobind==2.2.0

View File

@@ -3,10 +3,8 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4)) a = mx.ones((3, 4))
b = mx.ones((3, 4)) b = mx.ones((3, 4))
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu) c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
print(f"c shape: {c_cpu.shape}") print(f"c shape: {c.shape}")
print(f"c dtype: {c_cpu.dtype}") print(f"c dtype: {c.dtype}")
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}") print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")

View File

@@ -21,7 +21,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file. # Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}") target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>) target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
@@ -49,19 +49,5 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else() else()
target_sources(mlx add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
endif()
if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
endif() endif()

View File

@@ -10,7 +10,6 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/dtype.h" #include "mlx/dtype.h"
#include "mlx/event.h" #include "mlx/event.h"
#include "mlx/small_vector.h"
namespace mlx::core { namespace mlx::core {
@@ -19,8 +18,8 @@ class Primitive;
using Deleter = std::function<void(allocator::Buffer)>; using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t; using ShapeElem = int32_t;
using Shape = SmallVector<ShapeElem>; using Shape = std::vector<ShapeElem>;
using Strides = SmallVector<int64_t>; using Strides = std::vector<int64_t>;
class array { class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc /* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -225,10 +224,6 @@ class array {
// Not copyable // Not copyable
Data(const Data& d) = delete; Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete; Data& operator=(const Data& d) = delete;
Data(Data&& o) : buffer(o.buffer), d(o.d) {
o.buffer = allocator::Buffer(nullptr);
o.d = [](allocator::Buffer) {};
}
~Data() { ~Data() {
d(buffer); d(buffer);
} }
@@ -361,7 +356,7 @@ class array {
} }
enum Status { enum Status {
// The output of a computation which has not been scheduled. // The ouptut of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`. // For example, the status of `x` in `auto x = a + b`.
unscheduled, unscheduled,

View File

@@ -1,157 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cassert>
#include <functional>
#include <map>
namespace mlx::core {
template <typename T>
class BufferCache {
public:
BufferCache(
size_t page_size,
std::function<size_t(T*)> get_size,
std::function<void(T*)> free)
: page_size_(page_size),
get_size_(std::move(get_size)),
free_(std::move(free)) {}
~BufferCache() {
clear();
}
BufferCache(const BufferCache&) = delete;
BufferCache& operator=(const BufferCache&) = delete;
T* reuse_from_cache(size_t size) {
// Find the closest buffer in pool.
auto it = buffer_pool_.lower_bound(size);
if (it == buffer_pool_.end() ||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
return nullptr;
}
// Collect from the cache.
T* buf = it->second->buf;
pool_size_ -= it->first;
// Remove from record.
remove_from_list(it->second);
buffer_pool_.erase(it);
return buf;
}
void recycle_to_cache(T* buf) {
assert(buf);
// Add to cache.
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
size_t size = get_size_(buf);
pool_size_ += size;
buffer_pool_.emplace(size, bh);
}
int release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
// Release buffer.
size_t size = get_size_(tail_->buf);
total_bytes_freed += size;
free_(tail_->buf);
n_release++;
// Remove from record.
auto its = buffer_pool_.equal_range(size);
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
return el.second == tail_;
});
assert(it != buffer_pool_.end());
buffer_pool_.erase(it);
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
int clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
free_(holder->buf);
n_release++;
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
size_t cache_size() const {
return pool_size_;
}
size_t page_size() const {
return page_size_;
}
private:
struct BufferHolder {
public:
explicit BufferHolder(T* buf_) : buf(buf_) {}
BufferHolder* prev{nullptr};
BufferHolder* next{nullptr};
T* buf;
};
void add_at_head(BufferHolder* to_add) {
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void remove_from_list(BufferHolder* to_remove) {
if (to_remove->prev && to_remove->next) { // if middle
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // if tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // if head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // if only element
head_ = nullptr;
tail_ = nullptr;
}
delete to_remove;
}
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_{nullptr};
BufferHolder* tail_{nullptr};
size_t pool_size_{0};
const size_t page_size_;
std::function<size_t(T*)> get_size_;
std::function<void(T*)> free_;
};
} // namespace mlx::core

View File

@@ -1,7 +1,8 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h" #include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
@@ -14,8 +15,6 @@ void print_constant(std::ostream& os, const array& x) {
return print_float_constant<float16_t>(os, x); return print_float_constant<float16_t>(os, x);
case bfloat16: case bfloat16:
return print_float_constant<bfloat16_t>(os, x); return print_float_constant<bfloat16_t>(os, x);
case float64:
return print_float_constant<double>(os, x);
case complex64: case complex64:
return print_complex_constant<complex64_t>(os, x); return print_complex_constant<complex64_t>(os, x);
case int8: case int8:
@@ -52,8 +51,6 @@ std::string get_type_string(Dtype d) {
return "float16_t"; return "float16_t";
case bfloat16: case bfloat16:
return "bfloat16_t"; return "bfloat16_t";
case float64:
return "double";
case complex64: case complex64:
return "complex64_t"; return "complex64_t";
case bool_: case bool_:
@@ -82,6 +79,55 @@ std::string get_type_string(Dtype d) {
} }
} }
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
bool compiled_check_contiguity( bool compiled_check_contiguity(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const Shape& shape) { const Shape& shape) {
@@ -113,7 +159,8 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs( void compiled_allocate_outputs(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant, const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous) { bool contiguous) {
if (contiguous) { if (contiguous) {
int o = 0; int o = 0;
@@ -128,7 +175,8 @@ void compiled_allocate_outputs(
// - Donatable // - Donatable
// - Not a constant // - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && is_constant(i)) { in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in); outputs[o++].copy_shared_buffer(in);
} }
// Get representative input flags to properly set non-donated outputs // Get representative input flags to properly set non-donated outputs
@@ -156,7 +204,7 @@ void compiled_allocate_outputs(
// - Not a constant // - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() && if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
is_constant(i)) { constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o].copy_shared_buffer( outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size()); in, outputs[o].strides(), in.flags(), in.data_size());
o++; o++;
@@ -168,74 +216,4 @@ void compiled_allocate_outputs(
} }
} }
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant) {
const Shape& shape = out.shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
if (contiguous) {
return {true, shape, {}};
}
std::vector<Strides> strides_vec{out.strides()};
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants.
if (is_constant(i)) {
continue;
}
// Skip scalar inputs.
const auto& x = inputs[i];
if (is_scalar(x)) {
continue;
}
// Broadcast the inputs to the output shape.
Strides xstrides;
size_t j = 0;
for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides_vec.push_back(std::move(xstrides));
}
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
}
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous) {
if (contiguous) {
size_t max_size = 0;
for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
return max_size > UINT32_MAX;
} else {
size_t max_size = 0;
for (const auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
return max_size > UINT32_MAX;
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -1,8 +1,9 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
#include <functional>
#include <iomanip> #include <iomanip>
#include <sstream>
#include <unordered_set>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -13,17 +14,19 @@ inline bool is_static_cast(const Primitive& p) {
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType)); return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
} }
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
std::string get_type_string(Dtype d); std::string get_type_string(Dtype d);
template <typename T> template <typename T>
void print_float_constant(std::ostream& os, const array& x) { void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision(); auto old_precision = os.precision();
if constexpr (std::is_same_v<T, double>) { os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
os << std::setprecision(std::numeric_limits<double>::digits10 + 1); << x.item<T>() << std::setprecision(old_precision);
} else {
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
}
os << x.item<T>() << std::setprecision(old_precision);
} }
template <typename T> template <typename T>
@@ -57,19 +60,8 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs( void compiled_allocate_outputs(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant, const std::vector<array>& inputs_,
bool contiguous); const std::unordered_set<uintptr_t>& constant_ids_,
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant);
// Return whether the kernel should use large index.
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous); bool contiguous);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -2,7 +2,7 @@
#pragma once #pragma once
#include "mlx/backend/common/utils.h" #include "mlx/array.h"
namespace mlx::core { namespace mlx::core {
@@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types // If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output. // have the same size, then the input buffer can hold the output.
if (is_donatable(in, out)) { if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
return true; return true;
} else { } else {

View File

@@ -99,10 +99,6 @@ inline std::pair<int, int> decompose_hadamard(int n) {
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
} }
} }
if (n > (1 << 26)) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where k <= 26");
}
return {n, m}; return {n, m};
} }

View File

@@ -1,67 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include <sstream>
namespace mlx::core {
inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto a_batch_strides = batch_strides[0];
auto b_batch_strides = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
a_batch_strides.push_back(0);
b_batch_strides.push_back(0);
}
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
}
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
} // namespace mlx::core

View File

@@ -5,9 +5,11 @@
namespace mlx::core { namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes( std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape, const array& x,
Strides strides,
const std::vector<int>& axes) { const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) { for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i]; int a = axes[i];
shape.erase(shape.begin() + a); shape.erase(shape.begin() + a);
@@ -17,15 +19,6 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
return std::make_pair(shape, strides); return std::make_pair(shape, strides);
} }
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
return shapes_without_reduction_axes(
std::move(shape), std::move(strides), axes);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) { ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything // The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() && if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@@ -51,9 +51,5 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes( std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x, const array& x,
const std::vector<int>& axes); const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const std::vector<int>& axes);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -11,8 +11,6 @@ namespace mlx::core {
enum class TernaryOpType { enum class TernaryOpType {
ScalarScalarScalar, ScalarScalarScalar,
VectorVectorVector, VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General, General,
}; };
@@ -27,14 +25,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous && (a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) { c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector; topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else { } else {
topt = TernaryOpType::General; topt = TernaryOpType::General;
} }
@@ -69,8 +59,6 @@ inline void set_ternary_op_output_data(
b.flags()); b.flags());
} }
break; break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General: case TernaryOpType::General:
// Try to donate an input which is row_contiguous // Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) || if (!((a.flags().row_contiguous && maybe_donate(a)) ||

View File

@@ -1,26 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
} // namespace mlx::core

View File

@@ -1,22 +1,9 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
namespace mlx::core { namespace mlx::core {
std::filesystem::path current_binary_dir() {
static std::filesystem::path binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path();
}();
return binary_dir;
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims( std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape, const Shape& shape,
const std::vector<Strides>& strides, const std::vector<Strides>& strides,
@@ -114,118 +101,4 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap); return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
} }
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == pow2) {
break;
}
}
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
}
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor) {
// Compute the 2d grid dimensions such that the total size of the grid is
// divided by divisor.
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
// No need to add this shape we can just remove it from the divisor.
if (divisor % shape[i] == 0) {
divisor /= shape[i];
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
if (divisor > 1) {
if (grid_x % divisor == 0) {
grid_x /= divisor;
divisor = 1;
} else if (grid_y % divisor == 0) {
grid_y /= divisor;
divisor = 1;
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
if (divisor > 1) {
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
auto gx = (dim0 + bx - 1) / bx;
auto gy = (dim1 + by - 1) / by;
auto gz = (dim2 + bz - 1) / bz;
return std::make_pair(
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -2,17 +2,12 @@
#pragma once #pragma once
#include <filesystem>
#include <tuple>
#include <vector> #include <vector>
#include "mlx/array.h" #include "mlx/array.h"
namespace mlx::core { namespace mlx::core {
// Return the directory that contains current shared library.
std::filesystem::path current_binary_dir();
inline int64_t inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) { elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0; int64_t loc = 0;
@@ -75,31 +70,6 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a, const array& a,
int64_t size_cap = std::numeric_limits<int32_t>::max()); int64_t size_cap = std::numeric_limits<int32_t>::max());
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 2^pow2
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
// possibly broadcasted array
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
// Same as above but we do an implicit division with divisor.
// Basically, equivalent to factorizing
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor);
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
struct ContiguousIterator { struct ContiguousIterator {
inline void step() { inline void step() {
int dims = shape_.size(); int dims = shape_.size();
@@ -195,11 +165,4 @@ void shared_buffer_reshape(
const array& in, const array& in,
const Strides& out_strides, const Strides& out_strides,
array& out); array& out);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -40,13 +40,11 @@ add_dependencies(mlx cpu_compiled_preamble)
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

View File

@@ -14,8 +14,10 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) { void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis]; auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis]; auto axis_stride = in.strides()[axis];
Strides strides = remove_index(in.strides(), axis); Strides strides = in.strides();
Shape shape = remove_index(in.shape(), axis); Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
auto in_ptr = in.data<InT>(); auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>(); auto out_ptr = out.data<uint32_t>();

View File

@@ -1,11 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return true;
}
} // namespace mlx::core::cpu

View File

@@ -1,9 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cpu {
bool is_available();
} // namespace mlx::core::cpu

View File

@@ -172,12 +172,9 @@ void binary_float(
case bfloat16: case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt); binary_op<bfloat16_t, Op>(a, b, out, bopt);
break; break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[binary_float] Only supports floating point types."); "[binary_float] Only supports non-complex floating point types.");
} }
}); });
} }

View File

@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// The decomposition is computed in place, so just copy the input to the // The decomposition is computed in place, so just copy the input to the
// output. // output.
copy_cpu( copy(
a, a,
factor, factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -15,7 +15,6 @@
#include "mlx/backend/cpu/jit_compiler.h" #include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/graph_utils.h" #include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core { namespace mlx::core {
@@ -41,10 +40,7 @@ struct CompilerCache {
std::shared_mutex mtx; std::shared_mutex mtx;
}; };
static CompilerCache& cache() { static CompilerCache cache{};
static CompilerCache cache_;
return cache_;
};
// GPU compile is always available if the GPU is available and since we are in // GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available. // this file CPU compile is also available.
@@ -60,16 +56,14 @@ void* compile(
const std::string& kernel_name, const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) { const std::function<std::string(void)>& source_builder) {
{ {
std::shared_lock lock(cache().mtx); std::shared_lock lock(cache.mtx);
if (auto it = cache().kernels.find(kernel_name); if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
it != cache().kernels.end()) {
return it->second; return it->second;
} }
} }
std::unique_lock lock(cache().mtx); std::unique_lock lock(cache.mtx);
if (auto it = cache().kernels.find(kernel_name); if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
it != cache().kernels.end()) {
return it->second; return it->second;
} }
std::string source_code = source_builder(); std::string source_code = source_builder();
@@ -95,11 +89,7 @@ void* compile(
kernel_file_name = kernel_name; kernel_file_name = kernel_name;
} }
auto output_dir = auto output_dir = std::filesystem::temp_directory_path();
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
std::string shared_lib_name = "lib" + kernel_file_name + ".so"; std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string(); auto shared_lib_path = (output_dir / shared_lib_name).string();
@@ -130,10 +120,10 @@ void* compile(
} }
// load library // load library
cache().libs.emplace_back(shared_lib_path); cache.libs.emplace_back(shared_lib_path);
// Load function // Load function
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str()); void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) { if (!fun) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function " msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -141,7 +131,7 @@ void* compile(
<< dlerror(); << dlerror();
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
cache().kernels.insert({kernel_name, fun}); cache.kernels.insert({kernel_name, fun});
return fun; return fun;
} }
@@ -151,9 +141,18 @@ inline void build_kernel(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<array>& outputs, const std::vector<array>& outputs,
const std::vector<array>& tape, const std::vector<array>& tape,
const std::function<bool(size_t)>& is_constant, const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous, bool contiguous,
int ndim) { int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer; NodeNamer namer;
#ifdef _MSC_VER #ifdef _MSC_VER
@@ -162,28 +161,25 @@ inline void build_kernel(
#endif #endif
// Start the kernel // Start the kernel
os << "void " << kernel_name os << "void " << kernel_name << "(void** args) {" << std::endl;
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
// Add the input arguments // Add the input arguments
int cnt = 0; int cnt = 0;
int strides_index = 1; for (auto& x : inputs) {
for (size_t i = 0; i < inputs.size(); ++i) { auto& xname = namer.get_name(x);
// Skip constants from the input list // Skip constants from the input list
if (is_constant(i)) { if (is_constant(x)) {
continue; continue;
} }
const auto& x = inputs[i];
auto& xname = namer.get_name(x);
auto tstr = get_type_string(x.dtype()); auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++ os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl; << "];" << std::endl;
// Scalars and contiguous need no strides // Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) { if (!is_scalar(x) && !contiguous) {
os << " const int64_t* " << xname << "_strides = strides[" os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< strides_index++ << "];" << std::endl; << "];" << std::endl;
} }
} }
@@ -193,8 +189,10 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl; << "*)args[" << cnt++ << "];" << std::endl;
} }
// Add output size // Add output strides and shape to extract the indices.
if (contiguous) { if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl; os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
} }
@@ -208,11 +206,10 @@ inline void build_kernel(
} }
// Read the inputs in tmps // Read the inputs in tmps
for (size_t i = 0; i < inputs.size(); ++i) { for (auto& x : inputs) {
const auto& x = inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
if (is_constant(i)) { if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x); print_constant(os, x);
os << ";" << std::endl; os << ";" << std::endl;
@@ -236,7 +233,7 @@ inline void build_kernel(
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl; << namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else { } else {
os << x.primitive().name(); x.primitive().print(os);
os << "()("; os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
@@ -262,9 +259,8 @@ inline void build_kernel(
} else { } else {
for (int d = ndim - 1; d >= 0; --d) { for (int d = ndim - 1; d >= 0; --d) {
// Update pointers // Update pointers
for (size_t i = 0; i < inputs.size(); ++i) { for (auto& x : inputs) {
const auto& x = inputs[i]; if (is_constant(x) || is_scalar(x)) {
if (is_constant(i) || is_scalar(x)) {
continue; continue;
} }
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
@@ -286,33 +282,65 @@ inline void build_kernel(
void Compiled::eval_cpu( void Compiled::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
// Collapse contiguous dims to route to a faster kernel if possible. Also // Handle all broadcasting and collect function input arguments
// handle all broadcasting.
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Collect function input arguments.
std::vector<void*> args; std::vector<void*> args;
for (size_t i = 0; i < inputs.size(); ++i) { std::vector<std::vector<size_t>> strides;
if (is_constant_(i)) { for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue; continue;
} }
const auto& x = inputs[i]; auto& x = inputs[i];
encoder.set_input_array(x); encoder.set_input_array(x);
args.push_back((void*)x.data<void>()); args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
} }
// Get the kernel name from the lib // Get the kernel name from the lib
int ndim = shape.size(); int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) { if (!contiguous) {
kernel_name += std::to_string(ndim); kernel_name += std::to_string(shape.size());
} }
// Get the function // Get the function
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() { auto fn_ptr = compile(kernel_name, [&]() {
std::ostringstream kernel; std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl; kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl; kernel << "extern \"C\" {" << std::endl;
@@ -322,7 +350,7 @@ void Compiled::eval_cpu(
inputs_, inputs_,
outputs_, outputs_,
tape_, tape_,
is_constant_, constant_ids_,
contiguous, contiguous,
ndim); ndim);
// Close extern "C" // Close extern "C"
@@ -330,26 +358,26 @@ void Compiled::eval_cpu(
return kernel.str(); return kernel.str();
}); });
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous);
for (auto& x : outputs) { for (auto& x : outputs) {
args.push_back(x.data<void>()); args.push_back(x.data<void>());
encoder.set_output_array(x); encoder.set_output_array(x);
} }
if (contiguous) { Shape out_shape;
if (!contiguous) {
out_shape = outputs[0].shape();
args.push_back((void*)out_shape.data());
} else {
args.push_back((void*)outputs[0].data_size()); args.push_back((void*)outputs[0].data_size());
} }
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr); auto fun = (void (*)(void**))fn_ptr;
encoder.dispatch([fun, encoder.dispatch(
args = std::move(args), [fun,
strides = std::move(strides), args = std::move(args),
shape = std::move(shape)]() mutable { strides = std::move(strides),
SmallVector<int64_t*> strides_ptrs; out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
} }
} // namespace mlx::core } // namespace mlx::core

File diff suppressed because it is too large Load Diff

View File

@@ -295,11 +295,7 @@ inline void copy_inplace_dispatch(
} // namespace } // namespace
void copy_cpu_inplace( void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
const array& src,
array& dst,
CopyType ctype,
Stream stream) {
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src); encoder.set_input_array(src);
encoder.set_output_array(dst); encoder.set_output_array(dst);
@@ -309,7 +305,7 @@ void copy_cpu_inplace(
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); }); ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
} }
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) { void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype); bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) { if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to // If the output has the same type as the input then there is nothing to
@@ -319,10 +315,10 @@ void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
if (ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General; ctype = CopyType::General;
} }
copy_cpu_inplace(src, dst, ctype, stream); copy_inplace(src, dst, ctype, stream);
} }
void copy_cpu_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const Shape& data_shape, const Shape& data_shape,
@@ -377,10 +373,4 @@ void copy_cpu_inplace(
}); });
} }
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -10,14 +10,10 @@
namespace mlx::core { namespace mlx::core {
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream); void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace( void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
const array& src,
array& dst,
CopyType ctype,
Stream stream);
void copy_cpu_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const Shape& data_shape, const Shape& data_shape,
@@ -30,7 +26,4 @@ void copy_cpu_inplace(
const std::optional<array>& dynamic_i_offset = std::nullopt, const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt); const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -13,7 +13,9 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return {arr, false}; return {arr, false};
} else { } else {
return {contiguous_copy_cpu(arr, stream), true}; array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
} }
}; };
@@ -32,7 +34,8 @@ void AllReduce::eval_cpu(
} }
return in; return in;
} else { } else {
array arr_copy = contiguous_copy_cpu(in, s); array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy); out.copy_shared_buffer(arr_copy);
return arr_copy; return arr_copy;
} }

View File

@@ -1,174 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
}
encoder.dispatch([a_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
encoder.add_temporary(a);
}
} // namespace
void Eig::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy_cpu(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
values.set_data(allocator::malloc(values.nbytes()));
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.set_data(
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
}
switch (a.dtype()) {
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
}
}
} // namespace mlx::core

View File

@@ -12,133 +12,6 @@ namespace mlx::core {
namespace { namespace {
template <typename T, class Enable = void>
struct EighWork {};
template <typename T>
struct EighWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
char jobz;
char uplo;
int N;
int lwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, T* values) {
syevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<int*>(buffers[1].buffer.raw_ptr()),
&liwork,
&info);
}
};
template <>
struct EighWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
char jobz;
char uplo;
int N;
int lwork;
int lrwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
T work;
R rwork;
int iwork;
heevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&rwork,
&lrwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work.real());
lrwork = static_cast<int>(rwork);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, R* values) {
heevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&lrwork,
static_cast<int*>(buffers[2].buffer.raw_ptr()),
&liwork,
&info);
if (jobz == 'V') {
// We have pre-transposed the vectors but we also must conjugate them
// when they are complex.
//
// We could vectorize this but it is so fast in comparison to heevd that
// it doesn't really matter.
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
*vectors = std::conj(*vectors);
vectors++;
}
}
}
}
};
template <typename T> template <typename T>
void eigh_impl( void eigh_impl(
array& vectors, array& vectors,
@@ -146,10 +19,8 @@ void eigh_impl(
const std::string& uplo, const std::string& uplo,
bool compute_eigenvectors, bool compute_eigenvectors,
Stream stream) { Stream stream) {
using R = typename EighWork<T>::R;
auto vec_ptr = vectors.data<T>(); auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<R>(); auto eig_ptr = values.data<T>();
char jobz = compute_eigenvectors ? 'V' : 'N'; char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
@@ -162,17 +33,49 @@ void eigh_impl(
N = vectors.shape(-1), N = vectors.shape(-1),
size = vectors.size()]() mutable { size = vectors.size()]() mutable {
// Work query // Work query
EighWork<T> work(jobz, uplo, N); int lwork = -1;
int liwork = -1;
int info;
{
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
}
// Work loop auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) { for (size_t i = 0; i < size / (N * N); ++i) {
work.run(vec_ptr, eig_ptr); syevd<T>(
&jobz,
&uplo,
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
vec_ptr += N * N; vec_ptr += N * N;
eig_ptr += N; eig_ptr += N;
if (work.info != 0) { if (info != 0) {
std::stringstream msg; std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code " msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< work.info; << info;
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
} }
@@ -196,7 +99,7 @@ void Eigh::eval_cpu(
values.set_data(allocator::malloc(values.nbytes())); values.set_data(allocator::malloc(values.nbytes()));
copy_cpu( copy(
a, a,
vectors, vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -228,10 +131,6 @@ void Eigh::eval_cpu(
eigh_impl<double>( eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream()); vectors, values, uplo_, compute_eigenvectors_, stream());
break; break;
case complex64:
eigh_impl<std::complex<float>>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64."); "[Eigh::eval_cpu] only supports float32 or float64.");

View File

@@ -88,47 +88,4 @@ void matmul<double>(
} }
} }
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.flags().row_contiguous && in.is_donatable()) { if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy_cpu( copy(
in, in,
out, out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General, in.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -257,11 +257,15 @@ void gather_axis(
const array& ind, const array& ind,
array& out, array& out,
const int axis) { const int axis) {
auto shape = remove_index(ind.shape(), axis); auto strides = ind.strides();
ContiguousIterator ind_it( strides.erase(strides.begin() + axis);
shape, remove_index(ind.strides(), axis), src.ndim() - 1); auto shape = ind.shape();
ContiguousIterator src_it( shape.erase(shape.begin() + axis);
shape, remove_index(src.strides(), axis), src.ndim() - 1); ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>(); auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>(); auto src_ptr = src.data<T>();
@@ -517,7 +521,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out) // Copy src into out (copy allocates memory for out)
auto ctype = auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream()); copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds; std::vector<array> inds;
@@ -581,11 +585,15 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
template <typename T, typename IdxT, typename OpT> template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) { void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto shape = remove_index(idx.shape(), axis); auto strides = idx.strides();
ContiguousIterator idx_it( strides.erase(strides.begin() + axis);
shape, remove_index(idx.strides(), axis), upd.ndim() - 1); auto shape = idx.shape();
ContiguousIterator upd_it( shape.erase(shape.begin() + axis);
shape, remove_index(upd.strides(), axis), upd.ndim() - 1); ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>(); auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>(); auto upd_ptr = upd.data<T>();
@@ -686,7 +694,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out) // Copy src into out (copy allocates memory for out)
auto ctype = auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream()); copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx); encoder.set_input_array(idx);

View File

@@ -115,7 +115,7 @@ void inverse_impl(
// (A⁻¹)ᵀ = (Aᵀ)⁻¹ // (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output. // The inverse is computed in place, so just copy the input to the output.
copy_cpu( copy(
a, a,
inv, inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -2,7 +2,6 @@
#include "mlx/backend/cpu/jit_compiler.h" #include "mlx/backend/cpu/jit_compiler.h"
#include <algorithm>
#include <sstream> #include <sstream>
#include <vector> #include <vector>

View File

@@ -2,14 +2,14 @@
#pragma once #pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex> #include <complex>
#define LAPACK_COMPLEX_CUSTOM #define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float> #define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double> #define lapack_complex_double std::complex<double>
#define lapack_complex_float_real(z) ((z).real()) #endif
#define lapack_complex_float_imag(z) ((z).imag())
#define lapack_complex_double_real(z) ((z).real())
#define lapack_complex_double_imag(z) ((z).imag())
#ifdef MLX_USE_ACCELERATE #ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
@@ -32,7 +32,7 @@
#endif #endif
#define INSTANTIATE_LAPACK_REAL(FUNC) \ #define INSTANTIATE_LAPACK_TYPES(FUNC) \
template <typename T, typename... Args> \ template <typename T, typename... Args> \
void FUNC(Args... args) { \ void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \ if constexpr (std::is_same_v<T, float>) { \
@@ -42,24 +42,11 @@
} \ } \
} }
INSTANTIATE_LAPACK_REAL(geqrf) INSTANTIATE_LAPACK_TYPES(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr) INSTANTIATE_LAPACK_TYPES(orgqr)
INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_TYPES(syevd)
INSTANTIATE_LAPACK_REAL(geev) INSTANTIATE_LAPACK_TYPES(potrf)
INSTANTIATE_LAPACK_REAL(potrf) INSTANTIATE_LAPACK_TYPES(gesvdx)
INSTANTIATE_LAPACK_REAL(gesdd) INSTANTIATE_LAPACK_TYPES(getrf)
INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_TYPES(getri)
INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_TYPES(trtri)
INSTANTIATE_LAPACK_REAL(trtri)
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_COMPLEX(heevd)

View File

@@ -87,7 +87,8 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x; return x;
} else { } else {
array x_copy = contiguous_copy_cpu(x, s); auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy); encoder.add_temporary(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -31,7 +31,7 @@ void luf_impl(
strides[ndim - 1] = M; strides[ndim - 1] = M;
strides[ndim - 2] = 1; strides[ndim - 2] = 1;
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags); lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_cpu_inplace( copy_inplace(
a, a,
lu, lu,
a.shape(), a.shape(),

View File

@@ -6,7 +6,6 @@
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -53,58 +52,6 @@ inline void mask_matrix(
} }
} }
template <typename T>
inline void segmented_mm(
const T* a,
const T* b,
const uint32_t* segments,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides,
size_t num_segments,
const Shape& segments_shape,
const Strides& segments_strides) {
int ndim = a_shape.size();
Shape a_copy = a_shape;
Shape b_copy = b_shape;
int32_t M = a_copy[ndim - 2];
int32_t N = b_copy[ndim - 1];
for (int i = 0; i < num_segments; i++) {
uint32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
if (k_end <= k_start) {
std::fill_n(out + i * M * N, M * N, T(0));
continue;
}
a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start;
matmul<T>(
a + k_start * a_strides[ndim - 1],
b + k_start * b_strides[ndim - 2],
out + i * M * N,
a_transposed,
b_transposed,
lda,
ldb,
N,
1.0,
0.0,
1,
a_copy,
a_strides,
b_copy,
b_strides);
}
}
} // namespace } // namespace
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) { void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -124,20 +71,21 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (!expand_all && stx == arr.shape(-1) && sty == 1) { if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) { if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::Vector, s); copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy, true); return std::make_tuple(false, stx, arr_copy, true);
} }
return std::make_tuple(false, stx, arr, false); return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) { } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) { if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::Vector, s); copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy, true); return std::make_tuple(true, sty, arr_copy, true);
} }
return std::make_tuple(true, sty, arr, false); return std::make_tuple(true, sty, arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
array arr_copy = contiguous_copy_cpu(arr, s);
return std::make_tuple(false, stx, arr_copy, true); return std::make_tuple(false, stx, arr_copy, true);
} }
}; };
@@ -385,7 +333,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s); copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back()); return std::make_tuple(false, stx, temps.back());
} }
@@ -489,121 +437,4 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
} }
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cpu::get_command_encoder(stream());
auto check_transpose = [&s, &encoder](const array& x) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
if (stx == x.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, x);
} else if (stx == 1 && sty == x.shape(-2)) {
return std::make_tuple(true, sty, x);
} else {
array xc(x.shape(), x.dtype(), nullptr, {});
copy_cpu(x, xc, CopyType::General, s);
encoder.add_temporary(xc);
int64_t stx = x.shape(-1);
return std::make_tuple(false, stx, xc);
}
};
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
auto& segments = inputs[2];
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(segments);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
segments = array::unsafe_weak_copy(segments),
out_ptr = out.data<void>(),
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb]() {
switch (a.dtype()) {
case float64:
segmented_mm<double>(
a.data<double>(),
b.data<double>(),
segments.data<uint32_t>(),
static_cast<double*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float32:
segmented_mm<float>(
a.data<float>(),
b.data<float>(),
segments.data<uint32_t>(),
static_cast<float*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float16:
segmented_mm<float16_t>(
a.data<float16_t>(),
b.data<float16_t>(),
segments.data<uint32_t>(),
static_cast<float16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case bfloat16:
segmented_mm<bfloat16_t>(
a.data<bfloat16_t>(),
b.data<bfloat16_t>(),
segments.data<uint32_t>(),
static_cast<bfloat16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
default:
throw std::invalid_argument(
"Segmented mm supports only real float types.");
}
});
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -81,7 +81,7 @@ void matmul_general(
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, stream); copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1); stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back()); return std::make_tuple(false, stx, temps.back());
} }
@@ -108,9 +108,6 @@ void matmul_general(
} else if (out.dtype() == float64) { } else if (out.dtype() == float64) {
matmul_dispatch<double>( matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream); a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else { } else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type."); throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
} }
@@ -131,9 +128,9 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) { if (out.dtype() != float32) {
out.set_data(allocator::malloc(out.nbytes())); throw std::runtime_error(
return; "[AddMM::eval_cpu] Currently only supports float32.");
} }
// Fill output with C // Fill output with C
@@ -141,10 +138,8 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1 CopyType ctype = c.data_size() == 1
? CopyType::Scalar ? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream()); copy(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
} }

View File

@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out); auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) { if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream()); copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else { } else {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
} }
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(in, out, ctype, stream()); copy(in, out, ctype, stream());
} }
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) { void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i]; size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer( out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset); out, strides, flags, out_slice.size(), data_offset);
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream()); copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
} }
} }
@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
(allow_col_major_ && in.flags().col_contiguous))) { (allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy_cpu(in, out, CopyType::General, stream()); copy(in, out, CopyType::General, stream());
} }
} }
@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else { } else {
ctype = CopyType::General; ctype = CopyType::General;
} }
copy_cpu(in, out, ctype, stream()); copy(in, out, ctype, stream());
} }
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) { void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val // Fill output with val
copy_cpu(val, out, CopyType::Scalar, stream()); copy(val, out, CopyType::Scalar, stream());
// Find offset for start of input values // Find offset for start of input values
size_t data_offset = 0; size_t data_offset = 0;
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset); out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice // Copy input values into the slice
copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream()); copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
} }
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) { void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] = auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream()); compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_cpu_inplace( copy_inplace(
/* const array& src = */ in, /* const array& src = */ in,
/* array& dst = */ out, /* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(), /* const Shape& data_shape = */ out.shape(),
@@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size() auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector ? CopyType::Vector
: CopyType::General; : CopyType::General;
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto [out_offset, donated] = auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream()); compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_cpu_inplace( copy_inplace(
/* const array& src = */ upd, /* const array& src = */ upd,
/* array& dst = */ out, /* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(), /* const std::vector<int>& data_shape = */ upd.shape(),
@@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size() auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector ? CopyType::Vector
: CopyType::General; : CopyType::General;
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made // Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_); prepare_slice(out, start_indices_, strides_);
// Do copy // Do copy
copy_cpu_inplace( copy_inplace(
/* const array& src = */ upd, /* const array& src = */ upd,
/* array& dst = */ out, /* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(), /* const std::vector<int>& data_shape = */ upd.shape(),
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) { if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {}); auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in); in_tmp.copy_shared_buffer(in);
copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream()); copy_inplace(in_tmp, tmp, CopyType::General, stream());
} else { } else {
copy_cpu_inplace(in, tmp, CopyType::General, stream()); copy_inplace(in, tmp, CopyType::General, stream());
} }
auto flags = out.flags(); auto flags = out.flags();

View File

@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
strides[in.ndim() - 2] = 1; strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M; strides[in.ndim() - 1] = M;
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags); in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream); copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc(q.nbytes())); q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes())); r.set_data(allocator::malloc(r.nbytes()));

View File

@@ -1,5 +1,7 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
@@ -11,47 +13,9 @@ namespace mlx::core {
namespace { namespace {
const static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
auto power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int bits> template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) { void extract_bits(const uint8_t* w_in, T* w_out) {
static_assert(bits == 3 || bits == 5 || bits == 6); assert(bits == 3 || bits == 6);
if (bits == 3) { if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7); w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3); w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -61,16 +25,6 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1)); w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2); w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5); w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 5) {
w_out[0] = static_cast<T>(w_in[0] & 0x1f);
w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
} else if (bits == 6) { } else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f); w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] = w_out[1] =
@@ -92,8 +46,8 @@ void _qmm(
int N, int N,
int K) { int K) {
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = get_pack_factor(bits, 8); constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = get_bytes_per_pack(bits); constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
@@ -111,7 +65,7 @@ void _qmm(
T scale = *scales_local++; T scale = *scales_local++;
T bias = *biases_local++; T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) { for (int ng = 0; ng < packs_in_group; ng++) {
if constexpr (bits == 3 || bits == 5 || bits == 6) { if (bits == 3 || bits == 6) {
T wl[pack_factor]; T wl[pack_factor];
extract_bits<T, bits>(w_local, wl); extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
@@ -150,9 +104,8 @@ void _qmm_t(
int N, int N,
int K) { int K) {
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int pack_factor = get_pack_factor(bits, 8); constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
@@ -168,7 +121,7 @@ void _qmm_t(
T bias = *biases_local++; T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) { for (int kw = 0; kw < packs_in_group; kw++) {
if constexpr (bits == 3 || bits == 5 || bits == 6) { if (bits == 3 || bits == 6) {
T wl[pack_factor]; T wl[pack_factor];
extract_bits<T, bits>(w_local, wl); extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
@@ -351,10 +304,6 @@ void _qmm_dispatch_typed(
_qmm_dispatch_group<T, 4>( _qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w); result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break; break;
case 5:
_qmm_dispatch_group<T, 5>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 6: case 6:
_qmm_dispatch_group<T, 6>( _qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w); result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -434,231 +383,6 @@ void _qmm_dispatch(
} }
} }
template <typename T>
void mxfp4_qmm(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
for (int ng = 0; ng < packs_in_group; ng++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
}
}
result += N;
}
}
template <typename T>
void mxfp4_qmm_t(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
T gsum = 0;
for (int kw = 0; kw < packs_in_group; kw++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
sum += scale * gsum;
}
*result = sum;
result++;
}
x += K;
}
}
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
if constexpr (S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & 0xf;
simd::Simd<float, S> w_out;
for (int i = 0; i < S; ++i) {
w_out[i] = MXFP4_LUT[wi[i]];
}
return w_out;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T>
void mxfp4_qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (simd::max_size<T> % 8 == 0) {
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
} else {
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
}
} else {
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
}
}
template <typename T>
void mxfp4_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
for (int i = 0; i < batch_size; i++) {
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
switch (x.dtype()) {
case bfloat16:
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
break;
case float16:
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
break;
case float32:
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T> template <typename T>
void _bs_qmm_dispatch_typed( void _bs_qmm_dispatch_typed(
array& out, array& out,
@@ -765,198 +489,115 @@ void _bs_qmm_dispatch(
} }
} }
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
switch (x.dtype()) {
case float32:
mxfp4_bs_qmm_dispatch_typed<float>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case float16:
mxfp4_bs_qmm_dispatch_typed<float16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case bfloat16:
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace } // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) { void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];
auto& w_pre = inputs[1]; auto& w_pre = inputs[1];
auto& scales_pre = inputs[2]; auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& encoder = cpu::get_command_encoder(stream()); std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) { auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return arr; return arr;
} else { } else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, arr_cpy, CopyType::General, s); copy(arr, temps.back(), CopyType::General, s);
encoder.add_temporary(arr_cpy); return temps.back();
return arr_cpy;
} }
}; };
auto x = ensure_row_contiguous(x_pre); auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre); auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre); auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(scales); encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out); encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) { encoder.dispatch([out = array::unsafe_weak_copy(out),
auto biases = ensure_row_contiguous(inputs[3]); x = array::unsafe_weak_copy(x),
encoder.set_input_array(biases); w = array::unsafe_weak_copy(w),
encoder.dispatch([out = array::unsafe_weak_copy(out), scales = array::unsafe_weak_copy(scales),
x = array::unsafe_weak_copy(x), biases = array::unsafe_weak_copy(biases),
w = array::unsafe_weak_copy(w), group_size_ = group_size_,
scales = array::unsafe_weak_copy(scales), bits_ = bits_,
biases = array::unsafe_weak_copy(biases), transpose_ = transpose_]() mutable {
group_size_ = group_size_, _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
bits_ = bits_, });
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
} }
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];
auto& w_pre = inputs[1]; auto& w_pre = inputs[1];
auto& scales_pre = inputs[2]; auto& scales_pre = inputs[2];
auto& lhs_indices = inputs[inputs.size() - 2]; auto& biases_pre = inputs[3];
auto& rhs_indices = inputs[inputs.size() - 1]; auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto& encoder = cpu::get_command_encoder(stream()); std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(), auto ensure_row_contiguous_last_dims = [s = stream(),
&encoder](const array& arr) { &temps](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2]; auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1]; auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) { if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr; return arr;
} else { } else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, arr_cpy, CopyType::General, s); copy(arr, temps.back(), CopyType::General, s);
encoder.add_temporary(arr_cpy); return temps.back();
return arr_cpy;
} }
}; };
auto x = ensure_row_contiguous_last_dims(x_pre); auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre); auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre); auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(scales); encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices); encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices); encoder.set_input_array(rhs_indices);
encoder.set_output_array(out); encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) { encoder.dispatch([out = array::unsafe_weak_copy(out),
auto biases = ensure_row_contiguous_last_dims(inputs[3]); x = array::unsafe_weak_copy(x),
encoder.set_input_array(biases); w = array::unsafe_weak_copy(w),
encoder.dispatch([out = array::unsafe_weak_copy(out), scales = array::unsafe_weak_copy(scales),
x = array::unsafe_weak_copy(x), biases = array::unsafe_weak_copy(biases),
w = array::unsafe_weak_copy(w), lhs_indices = array::unsafe_weak_copy(lhs_indices),
scales = array::unsafe_weak_copy(scales), rhs_indices = array::unsafe_weak_copy(rhs_indices),
biases = array::unsafe_weak_copy(biases), group_size_ = group_size_,
lhs_indices = array::unsafe_weak_copy(lhs_indices), bits_ = bits_,
rhs_indices = array::unsafe_weak_copy(rhs_indices), transpose_ = transpose_]() mutable {
group_size_ = group_size_, _bs_qmm_dispatch(
bits_ = bits_, out,
transpose_ = transpose_]() mutable { x,
_bs_qmm_dispatch( w,
out, scales,
x, biases,
w, lhs_indices,
scales, rhs_indices,
biases, group_size_,
lhs_indices, bits_,
rhs_indices, transpose_);
group_size_, });
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
} }
template <typename T, typename U> template <typename T, typename U>
@@ -972,8 +613,9 @@ void quantize(
float eps = 1e-7; float eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits); bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = get_pack_factor(bits, 32); int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
int bytes_per_pack = get_bytes_per_pack(bits); // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int; int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_size / group_size; size_t n_groups = w_size / group_size;
@@ -998,21 +640,15 @@ void quantize(
} }
size_t out_idx = i * int_per_group; size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) { for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint64_t out_el = 0; uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) { for (int k = 0; k < el_per_int; ++k) {
float w_el = w[w_idx + j * el_per_int + k]; float w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - bias) / scale); w_el = std::rint((w_el - bias) / scale);
w_el = std::min(std::max(w_el, 0.0f), n_bins); w_el = std::min(std::max(w_el, 0.0f), n_bins);
out_el |= static_cast<uint64_t>(w_el) << (k * bits); out_el |= static_cast<uint32_t>(w_el) << (k * bits);
} }
if (power_of_2_bits) { if (power_of_2_bits) {
out[out_idx + j] = out_el; out[out_idx + j] = out_el;
} else if (bits == 5) {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
} else { } else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff; out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8; out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
@@ -1040,14 +676,16 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
} }
void fast::Quantize::eval_cpu( void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) { auto ensure_row_contiguous = [s = stream()](const array& arr) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return std::make_pair(arr, false); return std::make_pair(arr, false);
} else { } else {
return std::make_pair(contiguous_copy_cpu(arr, s), true); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
} }
}; };
@@ -1099,7 +737,7 @@ void fast::Quantize::eval_cpu(
} }
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"[fast::Quantize::eval_cpu] Only supports floating point inputs"); "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
} }
}); });
} }

View File

@@ -325,15 +325,7 @@ struct MaxReduce {
}; };
template <int N, typename T> template <int N, typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) { T operator()(simd::Simd<T, N> x) {
return simd::max(x);
};
template <int N, typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
if (simd::any(x != x)) {
return static_cast<T>(NAN);
}
return simd::max(x); return simd::max(x);
}; };
}; };
@@ -350,15 +342,7 @@ struct MinReduce {
}; };
template <int N, typename T> template <int N, typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) { T operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
template <int N, typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
if (simd::any(x != x)) {
return static_cast<T>(NAN);
}
return simd::min(x); return simd::min(x);
}; };
}; };
@@ -491,27 +475,19 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: case bool_:
case uint8: case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8: case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break; break;
case int16: case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break; break;
case int32: case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break; break;
case int64: case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break; break;
case float16: case float16:
@@ -551,10 +527,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_); reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break; break;
case int8: case int8:
reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_); reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break; break;
case int16: case int16:
reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_); reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break; break;
case int32: case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_); reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);

View File

@@ -250,8 +250,10 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// Ensure contiguity // Ensure contiguity
auto in = inputs[0]; auto in = inputs[0];
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
in = contiguous_copy_cpu(in, stream()); array arr_copy(in.shape(), in.dtype(), nullptr, {});
encoder.add_temporary(in); copy(in, arr_copy, CopyType::General, stream());
in = arr_copy;
encoder.add_temporary(arr_copy);
} }
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
@@ -328,8 +330,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
reduce_type_, in, out, axis_, reverse_, inclusive_); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case complex64: case complex64:
scan_dispatch<complex64_t, complex64_t>( throw std::runtime_error("Scan ops do not support complex types yet");
reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
} }
}); });

View File

@@ -9,7 +9,7 @@
#include "mlx/backend/cpu/simd/base_simd.h" #include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in simd/base_simd.h // There seems to be a bug in sims/base.h
// __XROS_2_0 is not defined, the expression evaluates // __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library // to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15 // higher than it should be even on macOS < 15
@@ -234,7 +234,6 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N> template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) { Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) { if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value)); return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) { } else if constexpr (sizeof(T1) == 2) {
@@ -252,13 +251,9 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value); return asd::pow(base.value, exp.value);
} else { } else {
Simd<T, N> res = 1; Simd<T, N> res = 1;
// Raising an integer to a negative power is undefined while (any(exp)) {
if (any(exp < 0)) { res = select(exp & 1, res * base, res);
return 0; base = select(exp, base * base, base);
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
exp = exp >> 1; exp = exp >> 1;
} }
return res; return res;

View File

@@ -88,33 +88,12 @@ DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log10, std::log10) DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh) DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log1p(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto x = in.value.real();
auto y = in.value.imag();
auto zabs = std::abs(in.value);
auto theta = std::atan2(y, x + 1);
if (zabs < 0.5) {
auto r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return Simd<T, 1>{T{x, theta}};
}
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
} else {
auto z0 = std::hypot(x + 1, y);
return Simd<T, 1>{T{std::log(z0), theta}};
}
} else {
return Simd<T, 1>{std::log1p(in.value)};
}
}
template <typename T> template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) { Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) { if constexpr (is_complex<T>) {

View File

@@ -131,7 +131,8 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
return x; return x;
} else { } else {
array x_copy = contiguous_copy_cpu(x, s); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -8,25 +8,13 @@
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
namespace { namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T> template <typename T>
struct StridedIterator { struct StridedIterator {
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
@@ -142,7 +130,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, nan_aware_less<T>); std::stable_sort(st, ed);
src_it.step(); src_it.step();
} }
} }
@@ -196,15 +184,6 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
@@ -240,7 +219,7 @@ void partition(array& out, int axis, int kth) {
StridedIterator md(data_ptr, axis_stride, kth); StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed, nan_aware_less<T>); std::nth_element(st, md, ed);
} }
} }
@@ -297,15 +276,6 @@ void argpartition(const array& in, array& out, int axis, int kth) {
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
@@ -363,24 +333,45 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
int axis = axis_;
if (axis < 0) {
axis += in.ndim();
}
// Copy input to output // Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0) CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
? CopyType::Vector copy(in, out, ctype, stream());
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable { encoder.dispatch(
dispatch_all_types(out.dtype(), [&](auto type_tag) { [out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
sort<MLX_GET_TYPE(type_tag)>(out, axis); switch (out.dtype()) {
}); case bool_:
}); return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(out, axis_);
case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
} }
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) { void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -435,10 +426,8 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0]; auto& in = inputs[0];
// Copy input to output // Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0) CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
? CopyType::Vector copy(in, out, ctype, stream());
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out); encoder.set_output_array(out);

View File

@@ -31,7 +31,7 @@ void svd_impl(
// lapack clobbers the input, so we have to make a copy. // lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {}); array in(a.shape(), a.dtype(), nullptr, {});
copy_cpu( copy(
a, a,
in, in,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -81,7 +81,9 @@ void svd_impl(
// Vᵀ of shape N x N. (M x M in lapack). // Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M; const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N"; auto job_u = (u_ptr) ? "V" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned. // Will contain the number of singular values after the call has returned.
int ns = 0; int ns = 0;
@@ -89,20 +91,30 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not // Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack). // used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)}; auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1; static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info; int info;
// Compute workspace size. // Compute workspace size.
gesdd<T>( gesvdx<T>(
/* jobz = */ jobz, /* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major. // M and N are swapped since lapack expects column-major.
/* m = */ &N, /* m = */ &N,
/* n = */ &M, /* n = */ &M,
/* a = */ nullptr, /* a = */ nullptr,
/* lda = */ &lda, /* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr, /* s = */ nullptr,
/* u = */ nullptr, /* u = */ nullptr,
/* ldu = */ &ldu, /* ldu = */ &ldu,
@@ -124,13 +136,20 @@ void svd_impl(
// Loop over matrices. // Loop over matrices.
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
gesdd<T>( gesvdx<T>(
/* jobz = */ jobz, /* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major. // M and N are swapped since lapack expects column-major.
/* m = */ &N, /* m = */ &N,
/* n = */ &M, /* n = */ &M,
/* a = */ in_ptr + M * N * i, /* a = */ in_ptr + M * N * i,
/* lda = */ &lda, /* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i, /* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U. // According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr, /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
@@ -148,6 +167,13 @@ void svd_impl(
ss << "svd_impl: sgesvdx_ failed with code " << info; ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
} }
}); });
encoder.add_temporary(in); encoder.add_temporary(in);

View File

@@ -2,13 +2,32 @@
#pragma once #pragma once
#include "mlx/backend/common/unary.h" #include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
void set_unary_output_data(const array& in, array& out) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
template <typename T, typename U = T, typename Op> template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, size_t shape, size_t stride) { void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) { for (size_t i = 0; i < shape; i += 1) {

View File

@@ -77,8 +77,7 @@ struct Real {
struct Sigmoid { struct Sigmoid {
template <int N, typename T> template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) { Simd<T, N> operator()(Simd<T, N> x) {
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x))); return 1.0f / (1.0f + simd::exp(-x));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
} }
SINGLE() SINGLE()
}; };

View File

@@ -1,180 +0,0 @@
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Embed kernel sources in binary for JIT compilation.
file(
GLOB MLX_JIT_SOURCES
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
add_custom_command(
OUTPUT gen/cuda_jit_sources.h
COMMAND
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
add_dependencies(mlx cuda_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
# Enable defining device lambda functions.
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
# Explicitly pass this flag to suppress the warning, it is safe to set it to
# true but the warning wouldn't be suppressed.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
target_compile_options(
mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
endif()
# Suppress warning when building for compute capability 7 used by V100.
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
# Use fixed version of CCCL.
FetchContent_Declare(
cccl
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
nvtx3
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
GIT_TAG v3.1.1
GIT_SHALLOW TRUE
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
# Make cuda runtime APIs available in non-cuda files.
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
# Use the frontend APIs of cuDNN.
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -1,270 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
#include <unistd.h>
#include <cassert>
namespace mlx::core {
namespace cu {
constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
auto num_blocks = small_pool_size / small_block_size;
buffer_ = new Block[num_blocks];
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
curr->next = buffer_ + i;
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(data_));
delete[] buffer_;
}
CudaBuffer* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
uint64_t i = next_free_ - buffer_;
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
return &b->buf;
}
void SmallSizePool::free(CudaBuffer* buf) {
auto b = reinterpret_cast<Block*>(buf);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(CudaBuffer* buf) {
constexpr int num_blocks = (small_pool_size / small_block_size);
auto b = reinterpret_cast<Block*>(buf);
int64_t block_num = b - buffer_;
return block_num >= 0 && block_num < num_blocks;
}
CudaAllocator::CudaAllocator()
: buffer_cache_(
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
}
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
} else if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure try to reclaim memory from the cache.
int64_t mem_to_free =
get_active_memory() + get_cache_memory() + size - memory_limit_;
if (mem_to_free > 0) {
buffer_cache_.release_cached_buffers(mem_to_free);
}
// Try the scalar pool first
if (size <= small_block_size) {
buf = scalar_pool_.malloc();
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
}
lock.lock();
}
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{buf};
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
std::unique_lock lock(mutex_);
active_memory_ -= buf->size;
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
cuda_free(buf);
}
}
size_t CudaAllocator::size(Buffer buffer) const {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return 0;
}
return buf->size;
}
// This must be called with mutex_ aquired
void CudaAllocator::cuda_free(CudaBuffer* buf) {
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf->data);
delete buf;
}
}
size_t CudaAllocator::get_active_memory() const {
return active_memory_;
}
size_t CudaAllocator::get_peak_memory() const {
return peak_memory_;
}
void CudaAllocator::reset_peak_memory() {
std::lock_guard lock(mutex_);
peak_memory_ = 0;
}
size_t CudaAllocator::get_memory_limit() {
return memory_limit_;
}
size_t CudaAllocator::set_memory_limit(size_t limit) {
std::lock_guard lock(mutex_);
std::swap(limit, memory_limit_);
return limit;
}
size_t CudaAllocator::get_cache_memory() const {
return buffer_cache_.cache_size();
}
size_t CudaAllocator::set_cache_limit(size_t limit) {
std::lock_guard lk(mutex_);
std::swap(limit, max_pool_size_);
return limit;
}
void CudaAllocator::clear_cache() {
std::lock_guard lk(mutex_);
buffer_cache_.clear();
}
CudaAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of CudaAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static CudaAllocator* allocator_ = new CudaAllocator;
return *allocator_;
}
} // namespace cu
namespace allocator {
Allocator& allocator() {
return cu::allocator();
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
}
} // namespace allocator
size_t get_active_memory() {
return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}
size_t get_cache_memory() {
return cu::allocator().get_cache_memory();
}
size_t set_cache_limit(size_t limit) {
return cu::allocator().set_cache_limit(limit);
}
void clear_cache() {
cu::allocator().clear_cache();
}
// Not supported in CUDA.
size_t set_wired_limit(size_t) {
return 0;
}
} // namespace mlx::core

View File

@@ -1,77 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include <mutex>
#include <set>
#include <utility>
namespace mlx::core::cu {
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
void* data;
size_t size;
};
class SmallSizePool {
private:
union Block {
Block* next;
CudaBuffer buf;
};
Block* buffer_{nullptr};
void* data_{nullptr};
Block* next_free_{nullptr};
public:
SmallSizePool();
~SmallSizePool();
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
CudaBuffer* malloc();
void free(CudaBuffer* buf);
bool in_pool(CudaBuffer* buf);
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
size_t get_cache_memory() const;
size_t set_cache_limit(size_t limit);
void clear_cache();
private:
void cuda_free(CudaBuffer* buf);
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();
} // namespace mlx::core::cu

View File

@@ -1,69 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, typename IdxT, int N_WRITES>
__global__ void arange(T* out, IdxT size, T start, T step) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_WRITES > size) {
for (IdxT i = index * N_WRITES; i < size; ++i) {
out[i] = start + i * step;
}
} else {
AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
for (int i = 0; i < N_WRITES; ++i) {
out_vec[i] = start + (index * N_WRITES + i) * step;
}
store_vector<N_WRITES>(out, index, out_vec);
}
}
} // namespace cu
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
constexpr int N_WRITES = 16 / sizeof(OutType);
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
encoder.add_kernel_node(
cu::arange<OutType, IdxT, N_WRITES>,
num_blocks,
block_dims,
0,
out.data<OutType>(),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
});
});
}
} // namespace mlx::core

View File

@@ -1,188 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
struct IndexValPair {
uint32_t index;
T val;
};
template <typename T>
struct ArgMin {
constexpr __device__ T init() {
return Limits<T>::max();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T>
struct ArgMax {
constexpr __device__ T init() {
return Limits<T>::min();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
const T* in,
uint32_t* out,
size_t size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides in_strides,
const __grid_constant__ Strides out_strides,
int32_t ndim,
int64_t axis_stride,
int32_t axis_size) {
auto block = cg::this_thread_block();
int64_t index = cg::this_grid().block_rank();
if (index >= size) {
return;
}
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
in += in_idx;
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto tid = r * BLOCK_DIM + block.thread_index().x;
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
best = op.reduce_many(best, vals, tid * N_READS);
}
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
best = BlockReduceT(temp).Reduce(best, op);
if (block.thread_rank() == 0) {
out[out_idx] = best.index;
}
}
} // namespace cu
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
Strides in_strides = remove_index(in.strides(), axis_);
Strides out_strides = out.ndim() == in.ndim()
? remove_index(out.strides(), axis_)
: out.strides();
int64_t axis_stride = in.strides()[axis_];
int32_t axis_size = in.shape()[axis_];
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4;
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
auto kernel =
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),
const_param(out_strides),
ndim,
axis_stride,
axis_size);
});
});
}
} // namespace mlx::core

View File

@@ -1,150 +0,0 @@
# Based on: https://github.com/sivachandran/cmake-bin2h
#
# Copyright 2020 Sivachandran Paramasivam
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
include(CMakeParseArguments)
# Function to wrap a given string into multiple lines at the given column
# position.
#
# Parameters:
#
# * VARIABLE - The name of the CMake variable holding the string.
# * AT_COLUMN - The column position at which string will be wrapped.
function(WRAP_STRING)
set(oneValueArgs VARIABLE AT_COLUMN)
cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
math(EXPR offset "0")
while(stringLength GREATER 0)
if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
math(EXPR length "${WRAP_STRING_AT_COLUMN}")
else()
math(EXPR length "${stringLength}")
endif()
string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
set(lines "${lines}\n ${line}")
math(EXPR stringLength "${stringLength} - ${length}")
math(EXPR offset "${offset} + ${length}")
endwhile()
set(${WRAP_STRING_VARIABLE}
"${lines}"
PARENT_SCOPE)
endfunction()
# Function to embed contents of a file as byte array in C/C++ header file(.h).
# The header file will contain a byte array and integer variable holding the
# size of the array.
#
# Parameters:
#
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
# the header file.
# * VARIABLE_NAME - The name of the variable for the byte array. The string
# "_SIZE" will be append to this name and will be used a variable name for
# size variable.
# * HEADER_FILE - The path of header file.
# * APPEND - If specified appends to the header file instead of overwriting it
# * HEADER_NAMESPACE - The namespace, where the array should be located in.
# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte
# array.
#
# Usage:
#
# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
function(BIN2H)
set(options APPEND NULL_TERMINATE)
set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
set(multiValueArgs SOURCE_FILES)
cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
set(arrayDefinition "")
foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
# get filename without extension
get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
# convert the filename to a valid C identifier
string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)
# reads source file contents as hex string
file(READ ${SOURCE_FILE} hexString HEX)
# append null
if(BIN2H_NULL_TERMINATE)
string(APPEND hexString "00")
endif()
# wraps the hex string into multiple lines
wrap_string(VARIABLE hexString AT_COLUMN 24)
# strip the © in source code
string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})
string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
${arrayValues})
# make a full variable name for the array
set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")
# declares byte array and the length variables
string(APPEND arrayDefinition
"constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
endforeach()
# add namespace wrapper if defined
if(DEFINED BIN2H_HEADER_NAMESPACE)
set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
endif()
set(arrayIncludes "#pragma once")
string(PREPEND declarations "${arrayIncludes}\n\n")
if(BIN2H_APPEND)
file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
else()
file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
endif()
endfunction()
# ----------------------------- CLI args -----------------------------
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
foreach(source ${MLX_JIT_SOURCES_LIST})
list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
endforeach()
bin2h(
SOURCE_FILES
${MLX_JIT_SOURCES_ABS}
NULL_TERMINATE
VARIABLE_NAME
"jit_source"
HEADER_NAMESPACE
"mlx::core"
HEADER_FILE
"${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")

View File

@@ -1,21 +0,0 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(ArcTan2)
} // namespace mlx::core

View File

@@ -1,379 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (int i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a[0], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[i]);
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a[0], b_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[0]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[i]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
return std::is_same_v<Out, bool>;
}
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> && is_inexact_v<In>;
}
if (std::is_same_v<Op, LogAddExp>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
std::is_same_v<Op, BitwiseXor>) {
return std::is_same_v<In, Out> && std::is_integral_v<In>;
}
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out.data_size() > INT32_MAX,
[&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
Shape shape;
std::vector<Strides> strides;
std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
rest,
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, name(), s); \
}
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More