Compare commits

..

1 Commits

Author SHA1 Message Date
Angelos Katharopoulos
a9c720e8cd Improve the ring backend initialization 2025-07-11 15:31:28 -07:00
422 changed files with 6659 additions and 22331 deletions

View File

@@ -7,9 +7,18 @@ parameters:
nightly_build:
type: boolean
default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
cuda_release:
type: boolean
default: false
jobs:
build_documentation:
@@ -18,17 +27,16 @@ jobs:
type: boolean
default: false
macos:
xcode: "26.0.0"
resource_class: m4pro.medium
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.10
brew install python@3.9
brew install doxygen
python3.10 -m venv env
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
@@ -65,9 +73,9 @@ jobs:
git push -f origin gh-pages
linux_build_and_test:
machine:
image: ubuntu-2204:current
resource_class: large
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
@@ -79,37 +87,35 @@ jobs:
- run:
name: Install dependencies
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
python3 setup.py develop
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
@@ -120,7 +126,7 @@ jobs:
parameters:
xcode_version:
type: string
default: "26.0.0"
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
@@ -128,56 +134,56 @@ jobs:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
command: |
uv venv --python 3.10
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
source env/bin/activate
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
pip install -e . -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build example extension
command: |
source .venv/bin/activate
source env/bin/activate
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source .venv/bin/activate
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
@@ -186,7 +192,7 @@ jobs:
- run:
name: Build small binary
command: |
source .venv/bin/activate
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
@@ -198,85 +204,43 @@ jobs:
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
source env/bin/activate
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
- run:
name: Run Python tests
command: |
source .venv/bin/activate
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
python_version:
type: string
default: "3.10"
default: "3.9"
xcode_version:
type: string
default: "26.0.0"
default: "16.2.0"
build_env:
type: string
default: ""
@@ -285,7 +249,7 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m4pro.medium
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
@@ -293,15 +257,11 @@ jobs:
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
@@ -311,38 +271,27 @@ jobs:
- run:
name: Install Python package
command: |
conda activate env
source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
conda activate env
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
source env/bin/activate
<< parameters.build_env >> python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
conda activate env
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
@@ -351,101 +300,89 @@ jobs:
parameters:
python_version:
type: string
default: "3.10"
build_env:
default: "3.9"
extra_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.build_env >> pip install ".[dev]" -v
<< parameters.extra_env >> pip install . -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
<< parameters.extra_env >> python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
python_version:
type: string
default: ""
default: "3.9"
extra_env:
type: string
default: "DEV_RELEASE=1"
machine:
image: ubuntu-2204:current
resource_class: xlarge
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Build wheel
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
python -m venv env
source env/bin/activate
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
<< parameters.extra_env >> \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
pip install ".[dev]" -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build --wheel
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
@@ -457,23 +394,22 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- cuda_build_and_test
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
@@ -484,10 +420,71 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
@@ -503,17 +500,8 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
prb:
when:
@@ -529,14 +517,11 @@ workflows:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
@@ -546,34 +531,148 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
- build_cuda_release
build_dev_release:
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
- << pipeline.parameters.weekly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["DEV_RELEASE=1"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
cuda_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.cuda_release >>
jobs:
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -1,24 +0,0 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
MLX_BUILD_STAGE: 2
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
python -m build -w
if [ -f "python/scripts/repair_cuda.sh" ]; then
bash python/scripts/repair_cuda.sh
fi

View File

@@ -1,68 +0,0 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
nvcc-location:
description: 'Location of nvcc compiler'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
# this value is dependent on the CUDA tools installed in the setup-linux workflow
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Check if build actually worked
shell: bash
run: python -c "import mlx.core"
- name: Run Python tests - CPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m tests discover python/tests -v
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- name: Build Python package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: ${{ inputs.nvcc-location }}

View File

@@ -1,38 +0,0 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
- name: Install dependencies
shell: sh
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs/build/html -L .
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error

View File

@@ -1,78 +0,0 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
build-type:
description: 'Build type'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
type: boolean
runs:
using: "composite"
steps:
- name: Set DEBUG
shell: sh
if: inputs.build-type == 'debug'
run: echo "DEBUG=1" >> $GITHUB_ENV
- name: Install Python package
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
run: pip install -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: sh
run: ./build/tests/tests
- name: Build Python package
if: inputs.build-type == 'release'
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
if [ -f "python/scripts/repair_linux.sh" ]; then
bash python/scripts/repair_linux.sh
fi
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64

View File

@@ -1,22 +0,0 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
runs:
using: "composite"
steps:
- name: Build Python package(s)
shell: bash
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv pip install build
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w

View File

@@ -1,124 +0,0 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
build-jit:
description: 'Whether to build with JIT'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Install dependencies
shell: sh
env:
DEBUG: 1
DEV_RELEASE: 1
run: |
uv pip install --upgrade pip cmake setuptools
uv pip install nanobind==2.4.0 \
numpy torch tensorflow unittest-xml-reporting
uv pip install -e . -v
- name: Generate package stubs
shell: bash
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
if: inputs.run-tests == 'true'
shell: bash
run: |
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project test.py
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
if: inputs.build-jit == 'true'
shell: bash
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
if: ${{ inputs.build-jit == 'true' && inputs.run-tests == 'true' }}
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv run -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
- name: Build macOS 13 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 13.0
- name: Build macOS 14 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
- name: Build macOS 15 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0

View File

@@ -1,83 +0,0 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
runner-type:
description: 'Whether to set this up as a linux or CUDA runner'
required: false
default: 'linux'
type: choice
options:
- linux
- cuda
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
sudo apt autoremove -y
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake
- name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages
id: install-cuda
if: inputs.runner-type == 'cuda'
env:
TZ: Etc/UTC
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
- name: Package and Driver Report
if: inputs.runner-type == 'cuda'
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"

View File

@@ -1,31 +0,0 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
install-mpi:
description: 'Whether to install MPI'
required: false
default: 'true'
type: boolean
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
if: inputs.install-mpi == 'true'
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- name: Setup uv
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ inputs.python-version }}
activate-environment: true

View File

@@ -1,6 +0,0 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

View File

@@ -1,28 +0,0 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

View File

@@ -1,93 +0,0 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
env:
MACOSX_DEPLOYMENT_TARGET: "15.0"
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
build_cuda_with_tests:
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release:
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7

View File

@@ -1,46 +1,20 @@
name: Build and Test
on: pull_request
permissions:
contents: read
on:
pull_request:
branches:
- main
jobs:
check_lint:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
mac_build_and_test:
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_documentation:
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files

View File

@@ -1,188 +0,0 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
permissions:
contents: read
jobs:
build_documentation:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
with:
build-type: release
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiples: true
path: artifacts
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiples: true
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: build_cuda_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: build_linux_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: build_mac_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1

View File

@@ -1,10 +1,4 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-yaml
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:

View File

@@ -19,17 +19,11 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with

View File

@@ -26,7 +26,6 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -42,9 +41,7 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests -------------------------
message(
@@ -67,17 +64,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
endif()
else()
set(MLX_BUILD_METAL OFF)
endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
# ----------------------------- Lib -----------------------------
@@ -88,26 +78,22 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(METAL_LIB)
message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
@@ -116,8 +102,7 @@ if(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -146,12 +131,6 @@ if(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
@@ -179,7 +158,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate not found, using default backend.")
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
@@ -255,16 +234,12 @@ target_include_directories(
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)

View File

@@ -2,7 +2,7 @@
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)
[**Examples**](#examples)
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.
improve MLX with the goal of quickly exploring new ideas.
The design of MLX is inspired by frameworks like
[NumPy](https://numpy.org/doc/stable/index.html),
@@ -68,30 +68,25 @@ in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
```bash
**With `pip`**:
```
pip install mlx
```
To install the CUDA backend on Linux, run:
**With `conda`**:
```bash
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
conda install -c conda-forge mlx
```
Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.
## Contributing
## Contributing
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
@@ -110,7 +105,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTex entry:
```text
```
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},

View File

@@ -142,7 +142,9 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -161,7 +163,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16", "complex64")
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
@@ -185,7 +187,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

View File

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True):
for dtype in ("float32", "float16", "complex64"):
for dtype in ("float32", "float16"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
)
)
plt.close(fig)

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -1,5 +1,4 @@
sphinx
breathe
sphinx-book-theme
sphinx-copybutton
mlx

View File

@@ -18,7 +18,6 @@ release = version
# -- General configuration ---------------------------------------------------
extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",

View File

@@ -127,8 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source,
ensure_row_contiguous=False,
source=source
)
def exp_elementwise(a: mx.array):
@@ -139,6 +138,7 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]

View File

@@ -138,13 +138,13 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** The name of primitive. */
const char* name() const override {
return "Axpby";
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/
@@ -394,14 +394,14 @@ below.
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::stream kname;
kname = "axpby_general_" + type_to_name(out);
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -70,7 +70,6 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/cuda
python/memory_management
python/nn
python/optimizers

View File

@@ -13,49 +13,32 @@ silicon computer is
pip install mlx
To install from PyPI your system must meet the following requirements:
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.10
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
MLX is also available on conda-forge. To install MLX with conda do:
.. code-block:: shell
conda install conda-forge::mlx
CUDA
^^^^
MLX has a CUDA backend which you can install with:
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.10
pip install mlx-cuda
Troubleshooting
@@ -271,7 +254,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag

View File

@@ -1,9 +0,0 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,4 +13,3 @@ Fast
rope
scaled_dot_product_attention
metal_kernel
cuda_kernel

View File

@@ -27,7 +27,6 @@ simple functions.
mish
prelu
relu
relu2
relu6
selu
sigmoid

View File

@@ -50,7 +50,6 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU2
ReLU6
RNN
RoPE

View File

@@ -112,7 +112,6 @@ Operations
max
maximum
mean
median
meshgrid
min
minimum

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors"))
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -19,4 +19,3 @@ Common Optimizers
Adamax
Lion
MultiOptimizer
Muon

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x)
timeit(mx.compile(gelu), x)
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]

View File

@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mx.nn.average_gradients(grads) # <---- This line was added
grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = tree_flatten(model.parameters(), destination={})
params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -164,13 +164,13 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
out, = imported_abs(mx.array([-1.0]))
# Also ok
out, = imported_abs(mx.array(-1.0))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -107,20 +107,8 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2025 Apple Inc.
#include <dlfcn.h>
#include <iostream>
#include <sstream>
@@ -17,19 +16,6 @@
namespace my_ext {
// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
static std::string binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path().string();
}();
return binary_dir;
}
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -181,15 +167,16 @@ void Axpby::eval_gpu(
}
// Resolve name of kernel (corresponds to axpby.metal)
std::string kname = "axpby_";
kname += (contiguous_kernel ? "contiguous_" : "general_");
kname += type_to_name(out);
std::ostringstream kname;
kname << "axpby_";
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** The name of primitive. */
const char* name() const override {
return "Axpby";
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.4.0
nanobind==2.2.0

View File

@@ -3,10 +3,8 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c_cpu.shape}")
print(f"c dtype: {c_cpu.dtype}")
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
explicit Buffer(void* ptr) : ptr_(ptr) {};
Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();

View File

@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->offset = other.array_desc_->offset;
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy;
}
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->offset = 0;
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
@@ -156,7 +156,7 @@ void array::set_data(
Flags flags,
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->offset = 0;
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
@@ -172,8 +172,9 @@ void array::copy_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
array_desc_->offset =
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
}
void array::copy_shared_buffer(const array& other) {
@@ -240,8 +241,8 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
init();
}

View File

@@ -10,7 +10,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"
namespace mlx::core {
@@ -19,8 +18,8 @@ class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -349,23 +348,15 @@ class array {
return array_desc_->data;
}
// Return a raw pointer to the arrays data. This function may do a copy if
// the underlying buffer is not accessible on the CPU. When accessing the
// data for GPU kernels, be sure to use the correct method / function for the
// given backend to access the GPU pointer.
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
return reinterpret_cast<T*>(
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
return static_cast<T*>(array_desc_->data_ptr);
}
template <typename T>
const T* data() const {
return const_cast<array&>(*this).data<T>();
}
int64_t offset() const {
return array_desc_->offset;
return static_cast<T*>(array_desc_->data_ptr);
}
enum Status {
@@ -469,8 +460,8 @@ class array {
// can share the underlying data buffer.
std::shared_ptr<Data> data;
// Offset from beginning of data pointer
int64_t offset{0};
// Properly offset data pointer
void* data_ptr{nullptr};
// The size in elements of the data buffer the array accesses
size_t data_size;

View File

@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
BinaryOpType bopt) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
out.set_data(
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
mallocfn(b.data_size() * out.itemsize()),
allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
mallocfn(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
mallocfn(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(mallocfn(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(allocator::malloc(0));
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);

View File

@@ -114,9 +114,7 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous,
const std::function<allocator::Buffer(size_t)>&
mallocfn /* = allocator::malloc */) {
bool contiguous) {
if (contiguous) {
int o = 0;
Strides strides;
@@ -142,7 +140,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
mallocfn(data_size * outputs[o].itemsize()),
allocator::malloc(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -165,7 +163,7 @@ void compiled_allocate_outputs(
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(mallocfn(outputs[o].nbytes()));
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
}
}
}

View File

@@ -58,9 +58,7 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous,
const std::function<allocator::Buffer(size_t)>& mallocfn =
allocator::malloc);
bool contiguous);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(

View File

@@ -22,11 +22,7 @@ enum class CopyType {
GeneralGeneral
};
inline bool set_copy_output_data(
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
@@ -35,14 +31,14 @@ inline bool set_copy_output_data(
return true;
} else {
out.set_data(
mallocfn(in.data_size() * out.itemsize()),
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(mallocfn(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
return false;
}
}

View File

@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}};
return {{1}, {0}, {0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
return {{1}, {0}, {0}, {0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};

View File

@@ -45,7 +45,7 @@ void slice(
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
out.set_data(allocator::malloc(0));
out.set_data(nullptr);
return;
}

View File

@@ -11,8 +11,6 @@ namespace mlx::core {
enum class TernaryOpType {
ScalarScalarScalar,
VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General,
};
@@ -27,14 +25,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else {
topt = TernaryOpType::General;
}
@@ -46,8 +36,7 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
TernaryOpType topt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
TernaryOpType topt) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
out.copy_shared_buffer(x);
@@ -58,25 +47,24 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
out.set_data(
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
mallocfn(out.itemsize() * b.data_size()),
allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
}
break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General:
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(mallocfn(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}

View File

@@ -7,22 +7,19 @@
namespace mlx::core {
inline void set_unary_output_data(
const array& in,
array& out,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
inline void set_unary_output_data(const array& in, array& out) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
mallocfn(in.data_size() * out.itemsize()),
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(mallocfn(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
}

View File

@@ -1,20 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
std::filesystem::path current_binary_dir() {
static std::filesystem::path binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path();
}();
return binary_dir;
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(

View File

@@ -2,7 +2,6 @@
#pragma once
#include <filesystem>
#include <tuple>
#include <vector>
@@ -10,8 +9,7 @@
namespace mlx::core {
// Return the directory that contains current shared library.
std::filesystem::path current_binary_dir();
std::string get_primitive_string(Primitive* primitive);
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
@@ -197,7 +195,7 @@ void shared_buffer_reshape(
array& out);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}

View File

@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// The decomposition is computed in place, so just copy the input to the
// output.
copy_cpu(
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -15,7 +15,6 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core {
@@ -95,11 +94,7 @@ void* compile(
kernel_file_name = kernel_name;
}
auto output_dir =
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
auto output_dir = std::filesystem::temp_directory_path();
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();
@@ -162,12 +157,10 @@ inline void build_kernel(
#endif
// Start the kernel
os << "void " << kernel_name
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(i)) {
@@ -182,8 +175,8 @@ inline void build_kernel(
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const int64_t* " << xname << "_strides = strides["
<< strides_index++ << "];" << std::endl;
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
}
}
@@ -193,8 +186,10 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output size
if (contiguous) {
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
@@ -236,7 +231,7 @@ inline void build_kernel(
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
os << x.primitive().name();
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
@@ -295,6 +290,7 @@ void Compiled::eval_cpu(
// Collect function input arguments.
std::vector<void*> args;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
continue;
@@ -302,6 +298,9 @@ void Compiled::eval_cpu(
const auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (!contiguous && !is_scalar(x)) {
args.push_back(strides[strides_index++].data());
}
}
// Get the kernel name from the lib
@@ -336,20 +335,16 @@ void Compiled::eval_cpu(
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
if (contiguous) {
if (!contiguous) {
args.push_back((void*)shape.data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
auto fun = (void (*)(void**))fn_ptr;
encoder.dispatch([fun,
args = std::move(args),
strides = std::move(strides),
shape = std::move(shape)]() mutable {
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
shape = std::move(shape)]() mutable { fun(args.data()); });
}
} // namespace mlx::core

View File

@@ -883,7 +883,7 @@ void explicit_gemm_conv_1D_cpu(
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
@@ -895,7 +895,7 @@ void explicit_gemm_conv_1D_cpu(
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
temps.push_back(in_padded_slice);
// Make strided view
@@ -920,7 +920,7 @@ void explicit_gemm_conv_1D_cpu(
// Materialize strided view
Shape strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
@@ -938,13 +938,13 @@ void explicit_gemm_conv_1D_cpu(
wt.size(),
0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
copy(wt_transpose, gemm_wt, CopyType::General, stream);
temps.push_back(gemm_wt);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
@@ -991,7 +991,132 @@ void explicit_gemm_conv_1D_cpu(
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
copy_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
encoder.set_input_array(in_strided);
encoder.set_input_array(gemm_wt);
encoder.set_output_array(gemm_out);
encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
gemm_wt_ptr = gemm_wt.data<float>(),
gemm_out_ptr = gemm_out.data<float>(),
strided_reshape = std::move(strided_reshape),
O]() {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided_ptr,
strided_reshape[1], // lda
gemm_wt_ptr,
strided_reshape[1], // ldb
0.0f, // beta
gemm_out_ptr,
O // ldc
);
});
// Copy results if needed
if (out.dtype() != float32) {
copy_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
@@ -1031,7 +1156,7 @@ void explicit_gemm_conv_ND_cpu(
// Fill with zeros
std::vector<array> temps = {array(0, conv_dtype)};
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = 0;
@@ -1048,7 +1173,7 @@ void explicit_gemm_conv_ND_cpu(
data_offset);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
temps.push_back(in_padded_slice);
// Make strided view
@@ -1087,7 +1212,7 @@ void explicit_gemm_conv_ND_cpu(
}
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
@@ -1098,13 +1223,13 @@ void explicit_gemm_conv_ND_cpu(
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (flip) {
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
copy(gemm_wt, gemm_wt_, CopyType::Vector, stream);
temps.push_back(gemm_wt_);
// Calculate the total size of the spatial dimensions
@@ -1159,7 +1284,7 @@ void explicit_gemm_conv_ND_cpu(
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
copy_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}

View File

@@ -295,11 +295,7 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream) {
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
@@ -309,7 +305,7 @@ void copy_cpu_inplace(
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
}
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to
@@ -319,10 +315,10 @@ void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_cpu_inplace(src, dst, ctype, stream);
copy_inplace(src, dst, ctype, stream);
}
void copy_cpu_inplace(
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
@@ -377,10 +373,4 @@ void copy_cpu_inplace(
});
}
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -10,14 +10,10 @@
namespace mlx::core {
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream);
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace(
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
@@ -30,7 +26,4 @@ void copy_cpu_inplace(
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core

View File

@@ -13,7 +13,9 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) {
return {arr, false};
} else {
return {contiguous_copy_cpu(arr, stream), true};
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
}
};
@@ -32,7 +34,8 @@ void AllReduce::eval_cpu(
}
return in;
} else {
array arr_copy = contiguous_copy_cpu(in, s);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}

View File

@@ -46,6 +46,7 @@ void eig_impl(
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,
@@ -134,7 +135,7 @@ void Eig::eval_cpu(
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy_cpu(
copy(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -196,7 +196,7 @@ void Eigh::eval_cpu(
values.set_data(allocator::malloc(values.nbytes()));
copy_cpu(
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
@@ -48,15 +49,9 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
}
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,

View File

@@ -88,47 +88,4 @@ void matmul<double>(
}
}
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core

View File

@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
copy_cpu(
copy(
in,
out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -517,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream());
copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
@@ -686,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream());
copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);

View File

@@ -115,7 +115,7 @@ void inverse_impl(
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy_cpu(
copy(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -2,7 +2,6 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include <algorithm>
#include <sstream>
#include <vector>

View File

@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)

View File

@@ -87,7 +87,8 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy = contiguous_copy_cpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}

View File

@@ -31,7 +31,7 @@ void luf_impl(
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_cpu_inplace(
copy_inplace(
a,
lu,
a.shape(),

View File

@@ -124,20 +124,21 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::Vector, s);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy, true);
}
return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::Vector, s);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy, true);
}
return std::make_tuple(true, sty, arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1);
array arr_copy = contiguous_copy_cpu(arr, s);
return std::make_tuple(false, stx, arr_copy, true);
}
};
@@ -215,18 +216,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(a);
encoder.set_input_array(b);
const void* a_mask_ptr = nullptr;
const void* b_mask_ptr = nullptr;
const void* out_mask_ptr = nullptr;
const void* a_mask_ptr;
const void* b_mask_ptr;
const void* out_mask_ptr;
Shape a_mask_shape;
Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
bool a_mask_bool = false;
bool b_mask_bool = false;
bool out_mask_bool = false;
bool a_mask_bool;
bool b_mask_bool;
bool out_mask_bool;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
@@ -385,7 +386,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return std::make_tuple(true, sty, arr);
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back());
}
@@ -423,6 +424,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());
@@ -502,7 +504,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return std::make_tuple(true, sty, x);
} else {
array xc(x.shape(), x.dtype(), nullptr, {});
copy_cpu(x, xc, CopyType::General, s);
copy(x, xc, CopyType::General, s);
encoder.add_temporary(xc);
int64_t stx = x.shape(-1);
return std::make_tuple(false, stx, xc);

View File

@@ -81,7 +81,7 @@ void matmul_general(
return std::make_tuple(true, sty, arr);
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, stream);
copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back());
}
@@ -91,6 +91,7 @@ void matmul_general(
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
@@ -107,9 +108,6 @@ void matmul_general(
} else if (out.dtype() == float64) {
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
@@ -130,6 +128,10 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;
@@ -140,7 +142,7 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
copy(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}

View File

@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(in, out, ctype, stream());
copy(in, out, ctype, stream());
}
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
}
@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_cpu(in, out, CopyType::General, stream());
copy(in, out, CopyType::General, stream());
}
}
@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
ctype = CopyType::General;
}
copy_cpu(in, out, ctype, stream());
copy(in, out, ctype, stream());
}
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy_cpu(val, out, CopyType::Scalar, stream());
copy(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -333,14 +333,14 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(allocator::malloc(0));
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_cpu_inplace(
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(allocator::malloc(0));
out.set_data(nullptr);
return;
}
@@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_cpu_inplace(
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(allocator::malloc(0));
out.set_data(nullptr);
return;
}
@@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_cpu_inplace(
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
copy_inplace(in_tmp, tmp, CopyType::General, stream());
} else {
copy_cpu_inplace(in, tmp, CopyType::General, stream());
copy_inplace(in, tmp, CopyType::General, stream());
}
auto flags = out.flags();

View File

@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes()));

View File

@@ -1,11 +1,10 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -14,35 +13,6 @@ namespace mlx::core {
namespace {
const static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
@@ -437,229 +407,6 @@ void _qmm_dispatch(
}
}
template <typename T>
void mxfp4_qmm(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
for (int ng = 0; ng < packs_in_group; ng++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
}
}
result += N;
}
}
template <typename T>
void mxfp4_qmm_t(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
T gsum = 0;
for (int kw = 0; kw < packs_in_group; kw++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
sum += scale * gsum;
}
*result = sum;
result++;
}
x += K;
}
}
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
if constexpr (S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & 0xf;
simd::Simd<float, S> w_out;
for (int i = 0; i < S; ++i) {
w_out[i] = MXFP4_LUT[wi[i]];
}
return w_out;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T>
void mxfp4_qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (simd::max_size<T> % 8 == 0) {
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
} else {
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
}
} else {
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
}
}
template <typename T>
void mxfp4_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
for (int i = 0; i < batch_size; i++) {
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
switch (x.dtype()) {
case bfloat16:
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
break;
case float16:
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
break;
case float32:
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T>
void _bs_qmm_dispatch_typed(
array& out,
@@ -766,198 +513,115 @@ void _bs_qmm_dispatch(
}
}
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
switch (x.dtype()) {
case float32:
mxfp4_bs_qmm_dispatch_typed<float>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case float16:
mxfp4_bs_qmm_dispatch_typed<float16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case bfloat16:
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& encoder = cpu::get_command_encoder(stream());
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_cpy, CopyType::General, s);
encoder.add_temporary(arr_cpy);
return arr_cpy;
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& lhs_indices = inputs[inputs.size() - 2];
auto& rhs_indices = inputs[inputs.size() - 1];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(),
&encoder](const array& arr) {
&temps](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_cpy, CopyType::General, s);
encoder.add_temporary(arr_cpy);
return arr_cpy;
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous_last_dims(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
}
template <typename T, typename U>
@@ -1041,14 +705,16 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::Quantize::eval_cpu(
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) {
if (arr.flags().row_contiguous) {
return std::make_pair(arr, false);
} else {
return std::make_pair(contiguous_copy_cpu(arr, s), true);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
}
};
@@ -1100,47 +766,7 @@ void fast::Quantize::eval_cpu(
}
} else {
throw std::runtime_error(
"[fast::Quantize::eval_cpu] Only supports floating point inputs");
}
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
});
}

View File

@@ -491,27 +491,19 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:

View File

@@ -250,8 +250,10 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// Ensure contiguity
auto in = inputs[0];
if (!in.flags().row_contiguous) {
in = contiguous_copy_cpu(in, stream());
encoder.add_temporary(in);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, stream());
in = arr_copy;
encoder.add_temporary(arr_copy);
}
out.set_data(allocator::malloc(out.nbytes()));

View File

@@ -1,6 +1,5 @@
#pragma once
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
@@ -10,7 +9,7 @@
#include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in simd/base_simd.h
// There seems to be a bug in sims/base.h
// __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15
@@ -201,15 +200,6 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);
@@ -244,7 +234,6 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) {
@@ -262,13 +251,9 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value);
} else {
Simd<T, N> res = 1;
// Raising an integer to a negative power is undefined
if (any(exp < 0)) {
return 0;
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
while (any(exp)) {
res = select(exp & 1, res * base, res);
base = select(exp, base * base, base);
exp = exp >> 1;
}
return res;

View File

@@ -171,11 +171,6 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;

View File

@@ -131,7 +131,8 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
array x_copy = contiguous_copy_cpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -8,25 +8,13 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
@@ -39,7 +27,7 @@ struct StridedIterator {
StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: stride_(stride), ptr_(ptr + offset * stride) {}
: ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
@@ -142,7 +130,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, nan_aware_less<T>);
std::stable_sort(st, ed);
src_it.step();
}
}
@@ -196,15 +184,6 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b);
});
}
@@ -240,7 +219,7 @@ void partition(array& out, int axis, int kth) {
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed, nan_aware_less<T>);
std::nth_element(st, md, ed);
}
}
@@ -297,15 +276,6 @@ void argpartition(const array& in, array& out, int axis, int kth) {
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b);
});
}
@@ -363,24 +333,45 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
int axis = axis_;
if (axis < 0) {
axis += in.ndim();
}
// Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
dispatch_all_types(out.dtype(), [&](auto type_tag) {
sort<MLX_GET_TYPE(type_tag)>(out, axis);
});
});
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_:
return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(out, axis_);
case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -435,10 +426,8 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);

View File

@@ -31,7 +31,7 @@ void svd_impl(
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {});
copy_cpu(
copy(
a,
in,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -81,26 +81,40 @@ void svd_impl(
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N";
auto job_u = (u_ptr) ? "V" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ jobz,
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
@@ -122,13 +136,20 @@ void svd_impl(
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesdd<T>(
/* jobz = */ jobz,
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
@@ -146,6 +167,13 @@ void svd_impl(
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
});
encoder.add_temporary(in);

View File

@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
simd::store(dst, Op{}(simd::load<T, N>(src)));
size -= N;
src += N;
dst += N;

View File

@@ -77,8 +77,7 @@ struct Real {
struct Sigmoid {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
return 1.0f / (1.0f + simd::exp(-x));
}
SINGLE()
};
@@ -108,73 +107,4 @@ struct Square {
SINGLE()
};
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail

View File

@@ -6,8 +6,8 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
@@ -15,27 +15,18 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
@@ -44,35 +35,15 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Embed kernel sources in binary for JIT compilation.
@@ -115,18 +86,11 @@ endif()
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"70;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
@@ -158,26 +122,6 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
# Use the frontend APIs of cuDNN.
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
#include <cuda_runtime.h>
@@ -17,101 +17,26 @@ namespace cu {
constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
auto num_blocks = small_pool_size / small_block_size;
buffer_ = new Block[num_blocks];
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
curr->next = buffer_ + i;
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(data_));
delete[] buffer_;
}
CudaBuffer* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
uint64_t i = next_free_ - buffer_;
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
b->buf.device = -1;
return &b->buf;
}
void SmallSizePool::free(CudaBuffer* buf) {
auto b = reinterpret_cast<Block*>(buf);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(CudaBuffer* buf) {
constexpr int num_blocks = (small_pool_size / small_block_size);
auto b = reinterpret_cast<Block*>(buf);
int64_t block_num = b - buffer_;
return block_num >= 0 && block_num < num_blocks;
}
CudaAllocator::CudaAllocator()
: buffer_cache_(
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
[this](CudaBuffer* buf) {
cuda_free(buf->data);
delete buf;
}) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
memory_limit_ = total * 0.8;
max_pool_size_ = memory_limit_;
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
CHECK_CUDA_ERROR(cudaSetDevice(i));
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
}
}
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
} else if (size < page_size) {
if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
@@ -119,34 +44,19 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure try to reclaim memory from the cache.
int64_t mem_to_free =
get_active_memory() + get_cache_memory() + size - memory_limit_;
if (mem_to_free > 0) {
buffer_cache_.release_cached_buffers(mem_to_free);
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache.
size_t mem_required = get_active_memory() + get_cache_memory() + size;
if (mem_required >= memory_limit_) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
}
// Try the scalar pool first
if (size <= small_block_size) {
buf = scalar_pool_.malloc();
}
lock.unlock();
if (!buf) {
int device = -1;
if (stream != nullptr) {
cudaStreamGetDevice(stream, &device);
}
buf = new CudaBuffer{nullptr, size, device};
cudaError_t err;
if (device == -1) {
err = cudaMallocManaged(&buf->data, size);
} else {
err = cudaMallocAsync(&buf->data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
lock.lock();
}
@@ -157,17 +67,10 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{buf};
}
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
return malloc_impl(size, stream);
}
Buffer CudaAllocator::malloc(size_t size) {
return malloc_impl(size, nullptr);
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
@@ -179,7 +82,9 @@ void CudaAllocator::free(Buffer buffer) {
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
cuda_free(buf);
lock.unlock();
cuda_free(buf->data);
delete buf;
}
}
@@ -191,18 +96,27 @@ size_t CudaAllocator::size(Buffer buffer) const {
return buf->size;
}
// This must be called with mutex_ aquired
void CudaAllocator::cuda_free(CudaBuffer* buf) {
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
if (buf->device >= 0) {
cudaFreeAsync(buf->data, free_streams_[buf->device]);
} else {
cudaFree(buf->data);
void CudaAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
void CudaAllocator::cuda_free(void* buf) {
// If cuda_free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
delete buf;
}
cudaFree(buf);
}
size_t CudaAllocator::get_active_memory() const {
@@ -251,16 +165,6 @@ CudaAllocator& allocator() {
return *allocator_;
}
Buffer malloc_async(size_t size, cudaStream_t stream) {
auto buffer = allocator().malloc_async(size, stream);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace cu
namespace allocator {
@@ -273,19 +177,7 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
if (cbuf.device != -1) {
// TODO maybe make this async on a i/o stream to avoid synchronizing the
// device on malloc/and free
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
cbuf.device = -1;
CHECK_CUDA_ERROR(
cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
CHECK_CUDA_ERROR(cudaFree(cbuf.data));
cbuf.data = new_data;
}
return cbuf.data;
return static_cast<cu::CudaBuffer*>(ptr_)->data;
}
} // namespace allocator

View File

@@ -4,54 +4,39 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/cuda/cuda_utils.h"
#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
void* data;
size_t size;
int device; // -1 for managed
};
class SmallSizePool {
private:
union Block {
Block* next;
CudaBuffer buf;
};
Block* buffer_{nullptr};
void* data_{nullptr};
Block* next_free_{nullptr};
public:
SmallSizePool();
~SmallSizePool();
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
CudaBuffer* malloc();
void free(CudaBuffer* buf);
bool in_pool(CudaBuffer* buf);
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, cudaStream_t stream);
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
@@ -62,24 +47,21 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf);
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();
Buffer malloc_async(size_t size, cudaStream_t stream);
} // namespace mlx::core::cu

View File

@@ -1,68 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, typename IdxT, int N_WRITES>
__global__ void arange(T* out, IdxT size, T start, T step) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_WRITES > size) {
for (IdxT i = index * N_WRITES; i < size; ++i) {
out[i] = start + i * step;
}
} else {
AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
for (int i = 0; i < N_WRITES; ++i) {
out_vec[i] = start + (index * N_WRITES + i) * step;
}
store_vector<N_WRITES>(out, index, out_vec);
}
}
} // namespace cu
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
constexpr int N_WRITES = 16 / sizeof(OutType);
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
encoder.add_kernel_node(
cu::arange<OutType, IdxT, N_WRITES>,
num_blocks,
block_dims,
0,
gpu_ptr<OutType>(out),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
});
});
}
} // namespace mlx::core

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -44,11 +44,8 @@ struct ArgMin {
}
template <int N>
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
@@ -77,11 +74,8 @@ struct ArgMax {
}
template <int N>
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
@@ -112,15 +106,16 @@ __global__ void arg_reduce_general(
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
in += in_idx;
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x;
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
}
@@ -140,10 +135,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
@@ -156,6 +149,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
@@ -172,9 +166,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
num_blocks,
block_dim(),
0,
gpu_ptr<T>(in),
gpu_ptr<uint32_t>(out),
in.data<T>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -28,7 +29,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a[0], b[0]);
out_vec.val[i] = Op{}(a[0], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -49,7 +50,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a[0], b_vec[i]);
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -70,7 +71,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b[0]);
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -92,96 +93,46 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -226,7 +177,7 @@ template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const char* op,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
@@ -259,61 +210,36 @@ void binary_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
auto kernel = cu::
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
num_blocks,
block_dims,
0,
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
}
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
num_blocks,
block_dims,
0,
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
@@ -323,7 +249,8 @@ void binary_op_gpu_inplace(
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
@@ -333,15 +260,19 @@ void binary_op_gpu_inplace(
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
@@ -360,24 +291,72 @@ template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const char* op,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(a, b, out, bopt);
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, name(), s); \
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, op, s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
break;
}
}
} // namespace mlx::core

View File

@@ -1,21 +0,0 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(ArcTan2)
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Divide)
} // namespace mlx::core

View File

@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Greater)
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More