mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
12 Commits
v0.29.4
...
27ff069175
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27ff069175 | ||
|
|
3b2ffcefc3 | ||
|
|
b65f882df3 | ||
|
|
b704e9e77a | ||
|
|
66519fb348 | ||
|
|
8973550ff3 | ||
|
|
3f866be665 | ||
|
|
23f81ed1c1 | ||
|
|
3fe2250c00 | ||
|
|
047114b988 | ||
|
|
9320eb89a8 | ||
|
|
75819d70ea |
@@ -1,579 +0,0 @@
|
||||
version: 2.1
|
||||
|
||||
orbs:
|
||||
apple: ml-explore/pr-approval@0.1.0
|
||||
|
||||
parameters:
|
||||
nightly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
test_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
parameters:
|
||||
upload-docs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "26.0.0"
|
||||
resource_class: m4pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install
|
||||
command: |
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
brew install python@3.10
|
||||
brew install doxygen
|
||||
python3.10 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install -r docs/requirements.txt
|
||||
pip install . -v
|
||||
- when:
|
||||
condition:
|
||||
not: << parameters.upload-docs >>
|
||||
steps:
|
||||
- run:
|
||||
name: Build documentation
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd docs && doxygen && make html O=-W
|
||||
- when:
|
||||
condition: << parameters.upload-docs >>
|
||||
steps:
|
||||
- add_ssh_keys:
|
||||
fingerprints:
|
||||
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
|
||||
- run:
|
||||
name: Upload documentation
|
||||
command: |
|
||||
source env/bin/activate
|
||||
git config user.email "mlx@group.apple.com"
|
||||
git config user.name "CircleCI Docs"
|
||||
git checkout gh-pages
|
||||
git rebase main
|
||||
cd docs
|
||||
git rm -rf build/html
|
||||
doxygen && make html O=-W
|
||||
git add -f build/html
|
||||
git commit -m "rebase"
|
||||
git push -f origin gh-pages
|
||||
|
||||
linux_build_and_test:
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Run style checks
|
||||
command: |
|
||||
pip install pre-commit
|
||||
pre-commit run --all
|
||||
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
uv venv
|
||||
uv pip install cmake
|
||||
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
uv pip install -e ".[dev]" -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
python -m unittest discover python/tests -v
|
||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j `nproc`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: ./build/tests/tests
|
||||
|
||||
mac_build_and_test:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "26.0.0"
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: m4pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
|
||||
brew install openmpi uv
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
uv venv --python 3.10
|
||||
uv pip install \
|
||||
nanobind==2.4.0 \
|
||||
cmake \
|
||||
numpy \
|
||||
torch \
|
||||
tensorflow \
|
||||
unittest-xml-reporting
|
||||
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
uv pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
cd examples/extensions
|
||||
uv pip install -r requirements.txt
|
||||
uv run --no-project setup.py build_ext --inplace
|
||||
uv run --no-project python test.py
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: |
|
||||
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
||||
- run:
|
||||
name: Build small binary
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
cd build/
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
uv pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
uv run --no-project python -m xmlrunner discover \
|
||||
-v python/tests \
|
||||
-o test-results/gpu_jit
|
||||
|
||||
cuda_build_and_test:
|
||||
parameters:
|
||||
image_date:
|
||||
type: string
|
||||
default: "2023.11.1"
|
||||
machine:
|
||||
image: "linux-cuda-12:<< parameters.image_date >>"
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- restore_cache:
|
||||
keys:
|
||||
- cuda-<< parameters.image_date >>-{{ arch }}-
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install libnccl2 libnccl-dev
|
||||
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
|
||||
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
||||
rm -rf ccache-4.11.3-linux-x86_64
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
- run:
|
||||
name: Set CCache size
|
||||
command: ccache --max-size 1G
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
uv venv
|
||||
uv pip install cmake
|
||||
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
uv pip install -e ".[dev]" -v
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
cmake . -B build \
|
||||
-DMLX_BUILD_CUDA=ON \
|
||||
-DCMAKE_CUDA_COMPILER=`which nvcc` \
|
||||
-DCMAKE_BUILD_TYPE=DEBUG
|
||||
cmake --build build -j `nproc`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
|
||||
- run:
|
||||
name: CCache report
|
||||
command: |
|
||||
ccache --show-stats
|
||||
ccache --zero-stats
|
||||
ccache --cleanup
|
||||
- save_cache:
|
||||
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
|
||||
paths:
|
||||
- /home/circleci/.cache/ccache
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.10"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "26.0.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: m4pro.medium
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
mkdir -p ~/miniconda3
|
||||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
source ~/miniconda3/bin/activate
|
||||
conda init --all
|
||||
conda create -n env python=<< parameters.python_version >> -y
|
||||
conda activate env
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install twine
|
||||
pip install build
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
conda activate env
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
conda activate env
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
conda activate env
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||
- when:
|
||||
condition:
|
||||
equal: ["3.10", << parameters.python_version >>]
|
||||
steps:
|
||||
- run:
|
||||
name: Build common package
|
||||
command: |
|
||||
conda activate env
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
conda activate env
|
||||
twine upload dist/*
|
||||
- store_artifacts:
|
||||
path: dist/
|
||||
|
||||
build_linux_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.10"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
PYTHON=python<< parameters.python_version >>
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
sudo apt-get update
|
||||
TZ=Etc/UTC sudo apt-get -y install tzdata
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
$PYTHON -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.build_env >> pip install ".[dev]" -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
|
||||
bash python/scripts/repair_linux.sh
|
||||
- when:
|
||||
condition:
|
||||
equal: ["3.10", << parameters.python_version >>]
|
||||
steps:
|
||||
- run:
|
||||
name: Build common package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||
python -m build -w
|
||||
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload packages
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
build_cuda_release:
|
||||
parameters:
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: xlarge
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
sudo apt-get update
|
||||
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install zip
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
python -m build -w
|
||||
bash python/scripts/repair_cuda.sh
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
and:
|
||||
- matches:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "15.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
image_date: ["2023.11.1", "2025.05.1"]
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
when:
|
||||
and:
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_release:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
upload-docs: true
|
||||
- build_linux_release:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
- build_cuda_release:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
|
||||
prb:
|
||||
when:
|
||||
matches:
|
||||
pattern: "^pull/\\d+(/head)?$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- hold:
|
||||
type: approval
|
||||
- apple/authenticate:
|
||||
context: pr-approval
|
||||
- mac_build_and_test:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "15.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
image_date: ["2023.11.1", "2025.05.1"]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.nightly_build >>
|
||||
jobs:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
- build_cuda_release
|
||||
|
||||
build_dev_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
- build_cuda_release:
|
||||
matrix:
|
||||
parameters:
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
8
.github/actions/build-docs/action.yml
vendored
8
.github/actions/build-docs/action.yml
vendored
@@ -7,6 +7,12 @@ runs:
|
||||
- name: Setup machine
|
||||
uses: ./.github/actions/setup-macos
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.10"
|
||||
activate-environment: true
|
||||
|
||||
- name: Install dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
@@ -35,4 +41,4 @@ runs:
|
||||
name: github-pages
|
||||
path: artifact.tar
|
||||
retention-days: 1
|
||||
if-no-files-found: error
|
||||
if-no-files-found: error
|
||||
|
||||
11
.github/actions/build-linux-release/action.yml
vendored
11
.github/actions/build-linux-release/action.yml
vendored
@@ -7,6 +7,13 @@ inputs:
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
arch:
|
||||
description: 'Platform architecture tag'
|
||||
required: true
|
||||
type: choice
|
||||
options:
|
||||
- x86_64
|
||||
- aarch64
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
@@ -23,11 +30,11 @@ runs:
|
||||
pip install auditwheel patchelf build
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 python -m build -w
|
||||
bash python/scripts/repair_linux.sh
|
||||
bash python/scripts/repair_linux.sh ${{ inputs.arch }}
|
||||
- name: Build backend package
|
||||
if: ${{ inputs.build-backend }}
|
||||
shell: bash
|
||||
run: |
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
|
||||
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
|
||||
|
||||
13
.github/actions/build-macos-release/action.yml
vendored
13
.github/actions/build-macos-release/action.yml
vendored
@@ -20,14 +20,17 @@ runs:
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
uv pip install build
|
||||
uv run --no-project setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 uv run -m build -w
|
||||
conda activate env
|
||||
pip install build
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 python -m build -w
|
||||
|
||||
- name: Build backend package
|
||||
if: ${{ inputs.build-backend }}
|
||||
shell: bash
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
uv run --no-project setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 uv run -m build -w
|
||||
conda activate env
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
|
||||
25
.github/actions/build-macos/action.yml
vendored
25
.github/actions/build-macos/action.yml
vendored
@@ -1,14 +1,31 @@
|
||||
name: 'Build and Test on macOS'
|
||||
description: 'Build and test MLX on macOS'
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: 'Python version to use'
|
||||
required: false
|
||||
default: '3.10'
|
||||
macos-target:
|
||||
description: 'macOS target to build and test for'
|
||||
required: false
|
||||
default: '14.0'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
activate-environment: true
|
||||
|
||||
- name: Install dependencies
|
||||
shell: sh
|
||||
env:
|
||||
DEBUG: 1
|
||||
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
uv pip install --upgrade pip
|
||||
uv pip install cmake setuptools nanobind==2.4.0
|
||||
@@ -29,6 +46,7 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
LOW_MEMORY: 1
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
@@ -38,6 +56,8 @@ runs:
|
||||
|
||||
- name: Build example extension
|
||||
shell: bash
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
cd examples/extensions
|
||||
uv pip install -r requirements.txt
|
||||
@@ -46,6 +66,8 @@ runs:
|
||||
|
||||
- name: Build CPP only
|
||||
shell: bash
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
mkdir -p build
|
||||
cd build
|
||||
@@ -62,6 +84,8 @@ runs:
|
||||
|
||||
- name: Build small binary with JIT
|
||||
shell: bash
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
mkdir -p build
|
||||
cd build
|
||||
@@ -80,6 +104,7 @@ runs:
|
||||
DEVICE: gpu
|
||||
METAL_DEVICE_WRAPPER_TYPE: 1
|
||||
METAL_DEBUG_ERROR_MODE: 0
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
uv pip install -e . -v
|
||||
|
||||
12
.github/actions/setup-macos/action.yml
vendored
12
.github/actions/setup-macos/action.yml
vendored
@@ -1,12 +1,6 @@
|
||||
name: 'Setup macOS Environment'
|
||||
description: 'Install dependencies for macOS builds'
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: 'Python version to use'
|
||||
required: false
|
||||
default: '3.10'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
@@ -17,9 +11,3 @@ runs:
|
||||
- name: Verify MetalToolchain installed
|
||||
shell: bash
|
||||
run: xcodebuild -showComponent MetalToolchain
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
activate-environment: true
|
||||
|
||||
6
.github/workflows/nightly.yml
vendored
6
.github/workflows/nightly.yml
vendored
@@ -21,6 +21,7 @@ jobs:
|
||||
- uses: ./.github/actions/build-linux-release
|
||||
with:
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
arch: "x86_64"
|
||||
- name: Upload mlx artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
@@ -40,7 +41,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
runs-on: ubuntu-22.04
|
||||
runner:
|
||||
- ubuntu-22.04
|
||||
- ubuntu-22.04-arm
|
||||
runs-on: ${{ matrix.runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
|
||||
13
.github/workflows/pull_request.yml
vendored
13
.github/workflows/pull_request.yml
vendored
@@ -14,7 +14,13 @@ jobs:
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
|
||||
linux_build_and_test:
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
matrix:
|
||||
runner:
|
||||
- ubuntu-22.04
|
||||
- ubuntu-22.04-arm
|
||||
fail-fast: false
|
||||
runs-on: ${{ matrix.runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
@@ -22,12 +28,17 @@ jobs:
|
||||
|
||||
mac_build_and_test:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
matrix:
|
||||
macos-target: ["14.0", "15.0"]
|
||||
runs-on: [self-hosted, macos]
|
||||
needs: check_lint
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-macos
|
||||
- uses: ./.github/actions/build-macos
|
||||
with:
|
||||
macos-target: ${{ matrix.macos-target }}
|
||||
|
||||
cuda_build_and_test:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
|
||||
50
.github/workflows/release.yml
vendored
50
.github/workflows/release.yml
vendored
@@ -5,6 +5,11 @@ on:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dev_release:
|
||||
description: "Do a dev release or regular release"
|
||||
required: true
|
||||
default: "false"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -12,9 +17,6 @@ permissions:
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
|
||||
pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
|
||||
steps:
|
||||
- name: Set publishing variables
|
||||
run: echo "Publishing setup complete"
|
||||
@@ -45,9 +47,15 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
runs-on: ubuntu-22.04
|
||||
include:
|
||||
- runner: ubuntu-24.04
|
||||
arch: x64
|
||||
- runner: ubuntu-24.04-arm64
|
||||
arch: arm64
|
||||
runs-on: ${{ matrix.runner }}
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
@@ -56,6 +64,7 @@ jobs:
|
||||
- uses: ./.github/actions/build-linux-release
|
||||
with:
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
arch: ${{ matrix.arch }}
|
||||
- name: Upload MLX artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
@@ -76,22 +85,27 @@ jobs:
|
||||
runs-on: [self-hosted, macos]
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-macos
|
||||
- uses: conda-incubator/setup-miniconda@v3
|
||||
with:
|
||||
miniconda-version: "latest"
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
uv pip install --upgrade pip
|
||||
uv pip install cmake setuptools nanobind==2.4.0
|
||||
uv pip install -e . -v
|
||||
pip install --upgrade pip
|
||||
pip install cmake setuptools nanobind==2.4.0
|
||||
pip install -e . -v
|
||||
- name: Generate package stubs
|
||||
shell: bash
|
||||
run: |
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- name: Build macOS 14 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
@@ -119,6 +133,7 @@ jobs:
|
||||
runs-on: ubuntu-22-large
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
@@ -141,7 +156,7 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: ${{ needs.setup.outputs.pypi_env }}
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
@@ -159,7 +174,7 @@ jobs:
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: ${{ needs.setup.outputs.pypi_url }}
|
||||
repository-url: https://upload.pypi.org/legacy/
|
||||
|
||||
pypi-publish-cuda:
|
||||
name: Upload CUDA release to PyPI
|
||||
@@ -168,7 +183,7 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: ${{ needs.setup.outputs.pypi_env }}
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-cuda
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
@@ -180,7 +195,7 @@ jobs:
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: ${{ needs.setup.outputs.pypi_url }}
|
||||
repository-url: https://upload.pypi.org/legacy/
|
||||
|
||||
pypi-publish-cpu:
|
||||
name: Upload CPU release to PyPI
|
||||
@@ -189,7 +204,7 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: ${{ needs.setup.outputs.pypi_env }}
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-cpu
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
@@ -201,7 +216,7 @@ jobs:
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: ${{ needs.setup.outputs.pypi_url }}
|
||||
repository-url: https://upload.pypi.org/legacy/
|
||||
|
||||
pypi-publish-metal:
|
||||
name: Upload Metal release to PyPI
|
||||
@@ -210,7 +225,7 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: ${{ needs.setup.outputs.pypi_env }}
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-metal
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
@@ -222,5 +237,4 @@ jobs:
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: ${{ needs.setup.outputs.pypi_url }}
|
||||
|
||||
repository-url: https://upload.pypi.org/legacy/
|
||||
|
||||
@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
|
||||
|
||||
void time_irregular_binary_ops_4D() {
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape = {8, 8, 512, 512};
|
||||
mx::Shape shape = {8, 8, 512, 512};
|
||||
auto a = mx::random::uniform(shape);
|
||||
auto b = mx::random::uniform(shape);
|
||||
|
||||
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
|
||||
|
||||
void time_irregular_reshape() {
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape;
|
||||
mx::Shape shape;
|
||||
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||
return mx::reshape(a, shape, device);
|
||||
};
|
||||
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
|
||||
void time_irregular_astype_2D() {
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
std::vector<int> shape = {size, size};
|
||||
mx::Shape shape = {size, size};
|
||||
|
||||
auto a = mx::random::uniform(shape);
|
||||
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||
|
||||
@@ -167,7 +167,7 @@ void array::copy_shared_buffer(
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
int64_t offset /* = 0 */) {
|
||||
array_desc_->data = other.array_desc_->data;
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
|
||||
@@ -439,7 +439,7 @@ class array {
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
int64_t offset = 0);
|
||||
|
||||
void copy_shared_buffer(const array& other);
|
||||
|
||||
|
||||
@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
|
||||
data_offset += start_indices[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides[i];
|
||||
}
|
||||
// Normalize the offset
|
||||
if (data_offset < 0) {
|
||||
data_offset += in.data_size();
|
||||
}
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
int64_t data_offset,
|
||||
size_t data_size,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
@@ -51,17 +47,24 @@ void slice(
|
||||
|
||||
// Calculate out strides, initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
int64_t data_end = 1;
|
||||
for (int i = 0; i < start_indices.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
|
||||
// Get the location of the end based on the inp strides and out.shape()
|
||||
int64_t low_idx = 0;
|
||||
int64_t high_idx = 0;
|
||||
for (int i = 0; i < inp_strides.size(); ++i) {
|
||||
auto delta = inp_strides[i] * (out.shape()[i] - 1);
|
||||
if (inp_strides[i] > 0) {
|
||||
high_idx += delta;
|
||||
} else {
|
||||
low_idx += delta;
|
||||
}
|
||||
}
|
||||
if (data_end < 0) {
|
||||
data_end += in.data_size();
|
||||
int64_t data_size = (high_idx - low_idx) + 1;
|
||||
if (data_size < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[slice] Computed invalid data size: " << data_size << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
size_t data_size = (data_end - data_offset);
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
|
||||
@@ -281,7 +281,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
Dtype dtype = out.dtype();
|
||||
|
||||
// Search cache.
|
||||
ConvCacheKey cache_key{
|
||||
BytesKey<ConvCacheKey> cache_key;
|
||||
cache_key.pod = {
|
||||
encoder.device().cuda_device(),
|
||||
dtype_to_cudnn_type(dtype),
|
||||
vector_key(in.shape()),
|
||||
|
||||
@@ -44,13 +44,13 @@ inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
// There are 2 differences from the const_param util from kernel_utils.cuh:
|
||||
// 1. The rest of array is filled with 0.
|
||||
// 2. This util can be used in .cpp files.
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
template <int NDIM = MAX_NDIM, typename T, template <typename U> class Vec>
|
||||
inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result = {};
|
||||
std::array<T, NDIM> result = {};
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ std::string build_kernel(
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos) {
|
||||
const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 8192);
|
||||
kernel_source += default_header;
|
||||
@@ -81,17 +81,17 @@ std::string build_kernel(
|
||||
kernel_source += ",\n";
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
if (std::get<0>(shape_infos[i])) {
|
||||
kernel_source += " const __grid_constant__ Shape ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_shape,\n";
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
if (std::get<1>(shape_infos[i])) {
|
||||
kernel_source += " const __grid_constant__ Strides ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_strides,\n";
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
if (std::get<2>(shape_infos[i])) {
|
||||
kernel_source += " const __grid_constant__ int ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_ndim,\n";
|
||||
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
|
||||
"[custom_kernel] Must specify at least one output.");
|
||||
}
|
||||
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
std::vector<std::tuple<bool, bool, bool>> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
std::tuple<bool, bool, bool> shape_info;
|
||||
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
|
||||
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
|
||||
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
|
||||
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
|
||||
std::optional<float> init_value,
|
||||
bool ensure_row_contiguous,
|
||||
StreamOrDevice s) {
|
||||
std::vector<CustomKernelShapeInfo> shape_infos(
|
||||
inputs.size(), CustomKernelShapeInfo{false, false, false});
|
||||
std::vector<std::tuple<bool, bool, bool>> shape_infos(
|
||||
inputs.size(), {false, false, false});
|
||||
return array::make_arrays(
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
|
||||
const array& in = checked_inputs[i];
|
||||
auto& shape_info = shape_infos_[i];
|
||||
args.append(in);
|
||||
if (shape_info.shape) {
|
||||
if (std::get<0>(shape_info)) {
|
||||
args.append_ndim(in.shape());
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
if (std::get<1>(shape_info)) {
|
||||
args.append_ndim(in.strides());
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
if (std::get<2>(shape_info)) {
|
||||
args.append<int32_t>(in.ndim());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,7 +368,7 @@ void CommandEncoder::commit() {
|
||||
}
|
||||
|
||||
void CommandEncoder::synchronize() {
|
||||
cudaStreamSynchronize(stream_);
|
||||
CHECK_CUDA_ERROR(cudaStreamSynchronize(stream_));
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||
|
||||
@@ -135,12 +135,19 @@ class LRUCache {
|
||||
};
|
||||
|
||||
// Turn a POD struct into a container key by doing bytes compare.
|
||||
//
|
||||
// Usage:
|
||||
// BytesKey<MyKey> key;
|
||||
// key.pod = { ... };
|
||||
template <typename T>
|
||||
struct BytesKey {
|
||||
T pod;
|
||||
static_assert(std::is_standard_layout_v<T>, "T is not POD");
|
||||
|
||||
BytesKey(T pod) : pod(std::move(pod)) {}
|
||||
BytesKey() {
|
||||
// Make sure the paddings between members are filled with 0.
|
||||
memset(&pod, 0, sizeof(T));
|
||||
}
|
||||
|
||||
BytesKey(const BytesKey& other) {
|
||||
memcpy(&pod, &other.pod, sizeof(T));
|
||||
|
||||
321
mlx/backend/cuda/scaled_dot_product_attention.cpp
Normal file
321
mlx/backend/cuda/scaled_dot_product_attention.cpp
Normal file
@@ -0,0 +1,321 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace fe = cudnn_frontend;
|
||||
|
||||
namespace {
|
||||
|
||||
#define CHECK_CUDNN_FE_ERROR(cmd) \
|
||||
do { \
|
||||
auto error = cmd; \
|
||||
if (!error.is_good()) { \
|
||||
throw std::runtime_error( \
|
||||
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
std::vector<int64_t> normalized_strides(const array& x) {
|
||||
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
|
||||
if (!x.flags().row_contiguous || x.ndim() < 2) {
|
||||
return strides;
|
||||
}
|
||||
for (int i = x.ndim() - 2; i >= 0; --i) {
|
||||
if (x.shape(i) == 1) {
|
||||
strides[i] = x.shape(i + 1) * strides[i + 1];
|
||||
}
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
void set_tensor_attrs(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x) {
|
||||
tensor->set_uid(uid)
|
||||
.set_dim({x.shape().begin(), x.shape().end()})
|
||||
.set_stride(normalized_strides(x));
|
||||
}
|
||||
|
||||
array prepare_sdpa_input(const array& x, Stream s) {
|
||||
// SDPA kernel's requirements on inputs:
|
||||
// 1. last dim's stride be 1;
|
||||
// 2. pointer be aligned.
|
||||
if (x.strides(-1) != 1 || get_alignment(x) < 16) {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
constexpr int QKV_NDIM = 4;
|
||||
|
||||
struct SDPACacheKey {
|
||||
int device_id;
|
||||
cudnnDataType_t cudnn_dtype;
|
||||
std::array<int, QKV_NDIM> q_shape;
|
||||
std::array<int, QKV_NDIM> k_shape;
|
||||
std::array<int, QKV_NDIM> v_shape;
|
||||
std::array<int64_t, QKV_NDIM> q_strides;
|
||||
std::array<int64_t, QKV_NDIM> k_strides;
|
||||
std::array<int64_t, QKV_NDIM> v_strides;
|
||||
bool do_causal;
|
||||
};
|
||||
|
||||
auto& sdpa_cache() {
|
||||
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
||||
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
enum UIDS {
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
SCALE,
|
||||
O,
|
||||
};
|
||||
|
||||
fe::graph::Graph build_sdpa_graph(
|
||||
cudnnHandle_t handle,
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool do_causal,
|
||||
const array& o) {
|
||||
auto dtype = fe::DataType_t::HALF;
|
||||
if (q.dtype() == bfloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
|
||||
fe::graph::Graph graph;
|
||||
graph.set_io_data_type(dtype)
|
||||
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||
|
||||
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
|
||||
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
|
||||
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
|
||||
set_tensor_attrs(q_, Q, q);
|
||||
set_tensor_attrs(k_, K, k);
|
||||
set_tensor_attrs(v_, V, v);
|
||||
|
||||
auto scale = graph.tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("Scale")
|
||||
.set_uid(SCALE)
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_is_pass_by_value(true)
|
||||
.set_data_type(fe::DataType_t::FLOAT));
|
||||
|
||||
auto sdpa_options = fe::graph::SDPA_attributes()
|
||||
.set_name("sdpa_cudnn")
|
||||
.set_attn_scale(scale)
|
||||
.set_causal_mask(do_causal)
|
||||
.set_generate_stats(false);
|
||||
|
||||
auto [o_, _] = graph.sdpa(q_, k_, v_, sdpa_options);
|
||||
o_->set_output(true);
|
||||
set_tensor_attrs(o_, O, o);
|
||||
|
||||
CHECK_CUDNN_FE_ERROR(graph.validate());
|
||||
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
|
||||
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
|
||||
graph.select_behavior_notes(
|
||||
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
|
||||
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
|
||||
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool supports_sdpa_cudnn(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
|
||||
if (!enabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// cuDNN SDPA requires Ampere and later.
|
||||
if (cu::device(s.device).compute_capability_major() < 8) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (has_mask) {
|
||||
// TODO: Support array masks.
|
||||
if (!do_causal) {
|
||||
return false;
|
||||
}
|
||||
// FIXME: Causal mask generates wrong results when L_Q != L_K.
|
||||
if (q.shape(2) != k.shape(2)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Only use cuDNN for prefilling.
|
||||
if (q.shape(2) != k.shape(2)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// D_qk and D_v must be a multiple of 8 with maximum value 128.
|
||||
if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) ||
|
||||
(v.shape(-1) > 128)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Dtype dtype = q.dtype();
|
||||
return dtype == float16 || dtype == bfloat16;
|
||||
}
|
||||
|
||||
void sdpa_cudnn(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
float scale,
|
||||
array& o,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
// TODO: Handle donation.
|
||||
// TODO: Make O use same memory layout with Q.
|
||||
o.set_data(cu::malloc_async(o.nbytes(), encoder.stream()));
|
||||
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
encoder.set_output_array(o);
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
// Search cache.
|
||||
BytesKey<SDPACacheKey> cache_key;
|
||||
cache_key.pod = {
|
||||
encoder.device().cuda_device(),
|
||||
dtype_to_cudnn_type(q.dtype()),
|
||||
vector_key<QKV_NDIM>(q.shape()),
|
||||
vector_key<QKV_NDIM>(k.shape()),
|
||||
vector_key<QKV_NDIM>(v.shape()),
|
||||
vector_key<QKV_NDIM>(q.strides()),
|
||||
vector_key<QKV_NDIM>(k.strides()),
|
||||
vector_key<QKV_NDIM>(v.strides()),
|
||||
do_causal,
|
||||
};
|
||||
auto it = sdpa_cache().find(cache_key);
|
||||
if (it == sdpa_cache().end()) {
|
||||
it =
|
||||
sdpa_cache()
|
||||
.emplace(cache_key, build_sdpa_graph(handle, q, k, v, do_causal, o))
|
||||
.first;
|
||||
}
|
||||
auto& graph = it->second;
|
||||
|
||||
std::unordered_map<int64_t, void*> variant_pack{
|
||||
{Q, const_cast<void*>(gpu_ptr<void>(q))},
|
||||
{K, const_cast<void*>(gpu_ptr<void>(k))},
|
||||
{V, const_cast<void*>(gpu_ptr<void>(v))},
|
||||
{SCALE, &scale},
|
||||
{O, gpu_ptr<void>(o)}};
|
||||
|
||||
int64_t workspace_size = 0;
|
||||
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
|
||||
void* workspace_ptr = nullptr;
|
||||
if (workspace_size > 0) {
|
||||
array workspace(
|
||||
cu::malloc_async(workspace_size, encoder.stream()),
|
||||
{static_cast<int>(workspace_size)},
|
||||
uint8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = gpu_ptr<void>(workspace);
|
||||
}
|
||||
|
||||
CudaGraph cuda_graph(encoder.device());
|
||||
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
|
||||
handle, variant_pack, workspace_ptr, cuda_graph));
|
||||
encoder.add_graph_node(cuda_graph);
|
||||
}
|
||||
|
||||
// Defined in scaled_dot_product_attention.cu file.
|
||||
bool supports_sdpa_vector(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal);
|
||||
void sdpa_vector(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
float scale,
|
||||
array& o,
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks,
|
||||
Stream s);
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
if (detail::in_grad_tracing()) {
|
||||
return true;
|
||||
}
|
||||
if (s.device == Device::cpu) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) &&
|
||||
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
|
||||
|
||||
auto& s = stream();
|
||||
|
||||
array q = prepare_sdpa_input(inputs[0], s);
|
||||
array k = prepare_sdpa_input(inputs[1], s);
|
||||
array v = prepare_sdpa_input(inputs[2], s);
|
||||
bool has_mask = inputs.size() - has_sinks_ > 3;
|
||||
bool has_arr_mask = has_mask && !do_causal_;
|
||||
|
||||
if (supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal_)) {
|
||||
if (has_sinks_) {
|
||||
sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
|
||||
} else {
|
||||
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
|
||||
}
|
||||
} else {
|
||||
sdpa_cudnn(q, k, v, scale_, out, do_causal_, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -6,10 +6,6 @@
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
@@ -663,23 +659,13 @@ void sdpa_vector_fallback(
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool ScaledDotProductAttention::use_fallback(
|
||||
bool supports_sdpa_vector(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
if (detail::in_grad_tracing()) {
|
||||
return true;
|
||||
}
|
||||
if (s.device == Device::cpu) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool do_causal) {
|
||||
const int value_head_dim = v.shape(-1);
|
||||
const int query_head_dim = q.shape(-1);
|
||||
const int query_sequence_length = q.shape(2);
|
||||
@@ -691,29 +677,24 @@ bool ScaledDotProductAttention::use_fallback(
|
||||
const bool supported_vector_config =
|
||||
sdpa_supported_head_dim && query_sequence_length < 4;
|
||||
|
||||
const bool supported_config = supported_vector_config;
|
||||
|
||||
return has_arr_mask || !supported_config;
|
||||
return supported_vector_config && !has_arr_mask;
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
|
||||
|
||||
auto& s = stream();
|
||||
void sdpa_vector(
|
||||
const array& q_pre,
|
||||
const array& k_pre,
|
||||
const array& v_pre,
|
||||
float scale,
|
||||
array& o,
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks_pre,
|
||||
Stream s) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
auto& q_pre = inputs[0];
|
||||
auto& k_pre = inputs[1];
|
||||
auto& v_pre = inputs[2];
|
||||
auto& o = out;
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
// Define some copy functions to ensure the layout of the inputs is as
|
||||
// expected.
|
||||
copies.reserve(inputs.size());
|
||||
copies.reserve(4);
|
||||
auto copy_unless = [&copies, &s](
|
||||
auto predicate, const array& arr) -> const array& {
|
||||
if (!predicate(arr)) {
|
||||
@@ -731,8 +712,8 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
};
|
||||
|
||||
std::optional<array> sinks = std::nullopt;
|
||||
if (has_sinks_) {
|
||||
sinks = copy_unless(is_matrix_contiguous, inputs.back());
|
||||
if (sinks_pre) {
|
||||
sinks = copy_unless(is_matrix_contiguous, sinks_pre.value());
|
||||
}
|
||||
|
||||
// We are in vector mode ie single query
|
||||
@@ -798,8 +779,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
encoder.add_temporary(cp);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(
|
||||
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
|
||||
sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);
|
||||
}
|
||||
|
||||
// Full attention mode should never reach here
|
||||
@@ -808,6 +788,4 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -44,7 +44,7 @@ void Worker::commit(cudaStream_t stream) {
|
||||
}
|
||||
signal_event_.record(stream);
|
||||
signal_event_.wait(signal_stream_);
|
||||
cudaLaunchHostFunc(signal_stream_, signal, this);
|
||||
CHECK_CUDA_ERROR(cudaLaunchHostFunc(signal_stream_, signal, this));
|
||||
}
|
||||
|
||||
void Worker::thread_fn() {
|
||||
|
||||
@@ -11,7 +11,7 @@ void slice_gpu(
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
const Stream&) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ std::string write_signature(
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<std::string>& attributes,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
const std::vector<std::tuple<bool, bool, bool>>& shape_infos,
|
||||
bool atomic_outputs) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 16384);
|
||||
@@ -88,19 +88,19 @@ std::string write_signature(
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
if (std::get<0>(shape_infos[i])) {
|
||||
kernel_source +=
|
||||
(" const constant int* " + name + "_shape [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
if (std::get<1>(shape_infos[i])) {
|
||||
kernel_source +=
|
||||
(" const constant int64_t* " + name + "_strides [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
if (std::get<2>(shape_infos[i])) {
|
||||
kernel_source +=
|
||||
(" const constant int& " + name + "_ndim [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
@@ -184,12 +184,12 @@ CustomKernelFunction metal_kernel(
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
std::vector<std::tuple<bool, bool, bool>> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
std::tuple<bool, bool, bool> shape_info;
|
||||
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
|
||||
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
|
||||
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
@@ -388,15 +388,15 @@ void CustomKernel::eval_gpu(
|
||||
index++;
|
||||
if (in.ndim() > 0) {
|
||||
int ndim = in.ndim();
|
||||
if (shape_info.shape) {
|
||||
if (std::get<0>(shape_info)) {
|
||||
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
if (std::get<1>(shape_info)) {
|
||||
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
if (std::get<2>(shape_info)) {
|
||||
compute_encoder.set_bytes(ndim, index);
|
||||
index++;
|
||||
}
|
||||
|
||||
@@ -75,6 +75,14 @@ constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;
|
||||
template <typename T>
|
||||
constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_optional =
|
||||
is_specialization_of<std::optional, std::decay_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_variant =
|
||||
is_specialization_of<std::variant, std::decay_t<T>>;
|
||||
|
||||
template <typename>
|
||||
constexpr bool dependent_false = false;
|
||||
|
||||
@@ -96,6 +104,12 @@ void reverse_bytes(T& data) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void serialize_variant(Writer& os, T v);
|
||||
|
||||
template <typename T>
|
||||
T deserialize_variant(Reader& is);
|
||||
|
||||
template <typename T>
|
||||
void serialize(Writer& os, T v) {
|
||||
if constexpr (std::is_arithmetic_v<T>) {
|
||||
@@ -113,6 +127,13 @@ void serialize(Writer& os, T v) {
|
||||
}
|
||||
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||
std::apply([&os](auto&... x) { (..., serialize(os, x)); }, v);
|
||||
} else if constexpr (is_variant<T>) {
|
||||
serialize_variant(os, v);
|
||||
} else if constexpr (is_optional<T>) {
|
||||
serialize(os, v.has_value());
|
||||
if (v.has_value()) {
|
||||
serialize(os, *v);
|
||||
}
|
||||
} else {
|
||||
NotSerializable<T>();
|
||||
}
|
||||
@@ -145,11 +166,58 @@ T deserialize(Reader& is) {
|
||||
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||
return deserialize_tuple<T>(
|
||||
is, std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>{});
|
||||
} else if constexpr (is_optional<T>) {
|
||||
auto has_value = deserialize<bool>(is);
|
||||
if (has_value) {
|
||||
return deserialize<T>(is);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
} else if constexpr (is_variant<T>) {
|
||||
return deserialize_variant<T>(is);
|
||||
} else {
|
||||
NotDeserializable<T>();
|
||||
}
|
||||
}
|
||||
|
||||
enum class VariantType { Int = 0, Float = 1, Bool = 2 };
|
||||
|
||||
template <typename T>
|
||||
void serialize_variant(Writer& os, T v) {
|
||||
std::visit(
|
||||
[&](auto&& x) {
|
||||
using ElemT = std::decay_t<decltype(x)>;
|
||||
if constexpr (std::is_same_v<ElemT, int>) {
|
||||
serialize(os, VariantType::Int);
|
||||
} else if constexpr (std::is_same_v<ElemT, float>) {
|
||||
serialize(os, VariantType::Float);
|
||||
} else if constexpr (std::is_same_v<ElemT, bool>) {
|
||||
serialize(os, VariantType::Bool);
|
||||
} else {
|
||||
static_assert(
|
||||
std::is_same_v<ElemT, void>, "Can't serialize variant type.");
|
||||
}
|
||||
serialize(os, x);
|
||||
},
|
||||
v);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T deserialize_variant(Reader& is) {
|
||||
auto vt = deserialize<VariantType>(is);
|
||||
switch (vt) {
|
||||
case VariantType::Int:
|
||||
return deserialize<int>(is);
|
||||
case VariantType::Float:
|
||||
return deserialize<float>(is);
|
||||
case VariantType::Bool:
|
||||
return deserialize<bool>(is);
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[deserialize_variant] Unknonw variant type tag.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, std::size_t... I>
|
||||
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {
|
||||
return T{deserialize<std::tuple_element_t<I, T>>(is)...};
|
||||
@@ -374,7 +442,8 @@ struct PrimitiveFactory {
|
||||
SERIALIZE_PRIMITIVE(LayerNorm),
|
||||
SERIALIZE_PRIMITIVE(LayerNormVJP),
|
||||
SERIALIZE_PRIMITIVE(RoPE),
|
||||
SERIALIZE_PRIMITIVE(ScaledDotProductAttention)};
|
||||
SERIALIZE_PRIMITIVE(ScaledDotProductAttention),
|
||||
SERIALIZE_PRIMITIVE(CustomKernel)};
|
||||
std::unordered_map<std::string, std::string> name_remap;
|
||||
|
||||
PrimitiveFactory() {
|
||||
@@ -647,7 +716,7 @@ void FunctionExporter::export_with_callback(
|
||||
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
|
||||
continue;
|
||||
}
|
||||
if (constants.insert(arr.id()).second) {
|
||||
if (constants.insert({arr.id(), arr}).second) {
|
||||
new_constants.emplace_back(namer.get_name(arr), arr);
|
||||
}
|
||||
}
|
||||
@@ -779,7 +848,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
if (input_set.find(arr.id()) == input_set.end()) {
|
||||
serialize(os, true);
|
||||
// Save constant data if not already saved
|
||||
if (constants.insert(arr.id()).second) {
|
||||
if (constants.insert({arr.id(), arr}).second) {
|
||||
serialize(os, arr.shape());
|
||||
serialize(os, arr.dtype());
|
||||
os.write(arr.data<char>(), arr.nbytes());
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
@@ -24,6 +25,9 @@ using StateT = std::variant<
|
||||
Strides,
|
||||
std::vector<int>,
|
||||
std::vector<size_t>,
|
||||
std::vector<std::tuple<bool, bool, bool>>,
|
||||
std::vector<std::variant<bool, int, float>>,
|
||||
std::optional<float>,
|
||||
std::string>;
|
||||
|
||||
using ExportCallbackInput = std::unordered_map<
|
||||
|
||||
@@ -72,7 +72,7 @@ struct FunctionExporter {
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<std::string>& kwarg_keys);
|
||||
std::set<std::uintptr_t> constants;
|
||||
std::unordered_map<std::uintptr_t, array> constants;
|
||||
int count{0};
|
||||
bool closed{false};
|
||||
std::shared_ptr<FunctionTable> ftable;
|
||||
|
||||
@@ -315,12 +315,6 @@ class Quantize : public Custom {
|
||||
bool dequantize_;
|
||||
};
|
||||
|
||||
struct CustomKernelShapeInfo {
|
||||
bool shape = false;
|
||||
bool strides = false;
|
||||
bool ndim = false;
|
||||
};
|
||||
|
||||
using ScalarArg = std::variant<bool, int, float>;
|
||||
|
||||
class CustomKernel : public Primitive {
|
||||
@@ -331,15 +325,15 @@ class CustomKernel : public Primitive {
|
||||
std::string source,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::vector<CustomKernelShapeInfo> shape_infos,
|
||||
std::vector<std::tuple<bool, bool, bool>> shape_infos,
|
||||
bool ensure_row_contiguous,
|
||||
std::optional<float> init_value,
|
||||
std::vector<ScalarArg> scalar_arguments,
|
||||
bool is_precompiled,
|
||||
int shared_memory)
|
||||
: Primitive(stream),
|
||||
source_(std::move(source)),
|
||||
name_(std::move(name)),
|
||||
source_(std::move(source)),
|
||||
grid_(grid),
|
||||
threadgroup_(threadgroup),
|
||||
shape_infos_(std::move(shape_infos)),
|
||||
@@ -358,13 +352,26 @@ class CustomKernel : public Primitive {
|
||||
override;
|
||||
|
||||
DEFINE_NAME(CustomKernel);
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
name_,
|
||||
source_,
|
||||
grid_,
|
||||
threadgroup_,
|
||||
shape_infos_,
|
||||
ensure_row_contiguous_,
|
||||
init_value_,
|
||||
scalar_arguments_,
|
||||
is_precompiled_,
|
||||
shared_memory_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string source_;
|
||||
std::string name_;
|
||||
std::string source_;
|
||||
std::tuple<int, int, int> grid_;
|
||||
std::tuple<int, int, int> threadgroup_;
|
||||
std::vector<CustomKernelShapeInfo> shape_infos_;
|
||||
std::vector<std::tuple<bool, bool, bool>> shape_infos_;
|
||||
bool ensure_row_contiguous_;
|
||||
std::optional<float> init_value_;
|
||||
std::vector<ScalarArg> scalar_arguments_;
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define MLX_VERSION_MAJOR 0
|
||||
#define MLX_VERSION_MINOR 29
|
||||
#define MLX_VERSION_PATCH 4
|
||||
#define MLX_VERSION_PATCH 5
|
||||
#define MLX_VERSION_NUMERIC \
|
||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
auditwheel repair dist/* \
|
||||
--plat manylinux_2_35_x86_64 \
|
||||
--plat manylinux_2_35_${1} \
|
||||
--only-plat \
|
||||
--exclude libmlx* \
|
||||
-w wheel_tmp
|
||||
|
||||
@@ -4335,7 +4335,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', dtype: Optional[Dtype], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dequantize the matrix ``w`` using quantization parameters.
|
||||
|
||||
|
||||
@@ -531,6 +531,71 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(keywords[0][0], "y")
|
||||
self.assertEqual(primitives, ["Subtract", "Abs", "Log"])
|
||||
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
|
||||
def test_export_import_custom_kernel(self):
|
||||
if mx.metal.is_available():
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out1[elem] = a[elem];
|
||||
"""
|
||||
custom_kernel = mx.fast.metal_kernel
|
||||
elif mx.cuda.is_available():
|
||||
source = """
|
||||
auto elem = cooperative_groups::this_grid().thread_rank();
|
||||
out1[elem] = a[elem];
|
||||
"""
|
||||
custom_kernel = mx.fast.cuda_kernel
|
||||
|
||||
kernel = custom_kernel(
|
||||
name="basic",
|
||||
input_names=["a"],
|
||||
output_names=["out1"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
def call(a):
|
||||
return kernel(
|
||||
inputs=[a],
|
||||
grid=(4, 1, 1),
|
||||
threadgroup=(2, 1, 1),
|
||||
output_shapes=[(2, 2)],
|
||||
output_dtypes=[mx.float32],
|
||||
stream=mx.gpu,
|
||||
)[0]
|
||||
|
||||
mx.random.seed(7)
|
||||
a = mx.random.normal(shape=(2, 2))
|
||||
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
expected = call(a)
|
||||
mx.export_function(path, call, a)
|
||||
|
||||
imported = mx.import_function(path)
|
||||
|
||||
out = imported(a)[0]
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_export_import_multi_with_constants(self):
|
||||
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
def fun(y):
|
||||
i = y.shape[0]
|
||||
x = mx.array(i)
|
||||
for j in range(10):
|
||||
x = x + mx.array(i + j)
|
||||
return x * y.sum()
|
||||
|
||||
ys = [mx.array([1]), mx.array([1, 1]), mx.array([1, 1, 1])]
|
||||
|
||||
with mx.exporter(path, fun) as exporter:
|
||||
for y in ys:
|
||||
exporter(y)
|
||||
|
||||
imported = mx.import_function(path)
|
||||
for y in ys:
|
||||
self.assertEqual(imported(y)[0].item(), fun(y).item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
@@ -168,7 +168,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
Dk = 64
|
||||
|
||||
if self.is_apple_silicon or mx.cuda.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
dtypes.append(np.half)
|
||||
|
||||
for SEQUENCE_LENGTH in [63, 129, 400]:
|
||||
@@ -240,7 +240,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
B = 1
|
||||
H = 32
|
||||
dtypes = [np.float32]
|
||||
if self.is_apple_silicon or mx.cuda.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
dtypes.append(np.half)
|
||||
|
||||
for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
|
||||
@@ -549,12 +549,8 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
class TestSDPA(mlx_tests.MLXTestCase):
|
||||
@property
|
||||
def dtypes(self):
|
||||
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
||||
|
||||
def test_sdpa(self):
|
||||
if not mx.metal.is_available():
|
||||
if not mx.is_available(mx.gpu):
|
||||
return
|
||||
|
||||
# fmt: off
|
||||
@@ -578,10 +574,11 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
# fmt: on
|
||||
|
||||
shapes = shapes_64 + shapes_128
|
||||
dtypes = ["float32", "float16"]
|
||||
masks = [None, "additive", "bool", "causal"]
|
||||
transposes = (False, True)
|
||||
|
||||
for dtype in self.dtypes:
|
||||
for dtype in dtypes:
|
||||
for t in transposes:
|
||||
for mask_str in masks:
|
||||
for B, qL, kL, D, qH, kH in shapes:
|
||||
|
||||
@@ -3058,6 +3058,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = a[::-1]
|
||||
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
||||
|
||||
a = mx.arange(8)
|
||||
for _ in range(4):
|
||||
a = a[::-1]
|
||||
self.assertTrue(mx.array_equal(a, mx.arange(8)))
|
||||
|
||||
def test_complex_ops(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
||||
4
setup.py
4
setup.py
@@ -24,11 +24,11 @@ def get_version():
|
||||
if "#define MLX_VERSION_PATCH" in l:
|
||||
patch = l.split()[-1]
|
||||
version = f"{major}.{minor}.{patch}"
|
||||
if "PYPI_RELEASE" not in os.environ:
|
||||
if os.environ.get("PYPI_RELEASE", False):
|
||||
today = datetime.date.today()
|
||||
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
||||
|
||||
if "DEV_RELEASE" not in os.environ:
|
||||
if os.environ.get("DEV_RELEASE", False):
|
||||
git_hash = (
|
||||
run(
|
||||
"git rev-parse --short HEAD".split(),
|
||||
|
||||
@@ -425,9 +425,11 @@ TEST_CASE("test matrix pseudo-inverse") {
|
||||
const auto A = array({1.0, 2.0, 3.0, 4.0}, {2, 2});
|
||||
const auto A_pinv = linalg::pinv(A, Device::cpu);
|
||||
const auto A_again = matmul(matmul(A, A_pinv), A);
|
||||
CHECK(allclose(A_again, A).item<bool>());
|
||||
CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5)
|
||||
.item<bool>());
|
||||
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
|
||||
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
|
||||
CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5)
|
||||
.item<bool>());
|
||||
}
|
||||
{ // Rectangular matrix m < n
|
||||
const auto prng_key = random::key(42);
|
||||
@@ -437,9 +439,11 @@ TEST_CASE("test matrix pseudo-inverse") {
|
||||
CHECK_FALSE(allclose(zeros, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
const auto A_again = matmul(matmul(A, A_pinv), A);
|
||||
CHECK(allclose(A_again, A).item<bool>());
|
||||
CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5)
|
||||
.item<bool>());
|
||||
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
|
||||
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
|
||||
CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5)
|
||||
.item<bool>());
|
||||
}
|
||||
{ // Rectangular matrix m > n
|
||||
const auto prng_key = random::key(10);
|
||||
@@ -449,9 +453,11 @@ TEST_CASE("test matrix pseudo-inverse") {
|
||||
CHECK_FALSE(allclose(zeros2, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
const auto A_again = matmul(matmul(A, A_pinv), A);
|
||||
CHECK(allclose(A_again, A).item<bool>());
|
||||
CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5)
|
||||
.item<bool>());
|
||||
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
|
||||
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
|
||||
CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5)
|
||||
.item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -292,7 +292,7 @@ TEST_CASE("test slice") {
|
||||
|
||||
out = slice(x, {0}, {4}, {2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 4);
|
||||
CHECK_EQ(out.data_size(), 3);
|
||||
|
||||
x = ones({4, 4});
|
||||
out = slice(x, {0, 0}, {2, 4});
|
||||
@@ -325,6 +325,20 @@ TEST_CASE("test slice") {
|
||||
out = slice(x, {2, 2, 2}, {3, 4, 3});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 5);
|
||||
|
||||
x = ones({8});
|
||||
out = slice(x, {7}, {-9}, {-1});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 8);
|
||||
|
||||
out = slice(x, {7}, {-9}, {-1});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 8);
|
||||
|
||||
x = ones({4, 2});
|
||||
out = slice(x, {3, 0}, {-5, 2}, {-1, 1});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 8);
|
||||
}
|
||||
|
||||
TEST_CASE("test slice update") {
|
||||
|
||||
Reference in New Issue
Block a user