Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: ibv-backen ... steel-refa (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 4c46e17a5d |  |
.circleci/config.yml  (new file)
@@ -0,0 +1,600 @@

version: 2.1

orbs:
  apple: ml-explore/pr-approval@0.1.0

parameters:
  nightly_build:
    type: boolean
    default: false
  weekly_build:
    type: boolean
    default: false
  test_release:
    type: boolean
    default: false
  linux_release:
    type: boolean
    default: false

jobs:
  build_documentation:
    parameters:
      upload-docs:
        type: boolean
        default: false
    macos:
      xcode: "16.2.0"
    resource_class: m2pro.medium
    steps:
      - checkout
      - run:
          name: Install
          command: |
            brew install python@3.9
            brew install doxygen
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
      - when:
          condition:
            not: << parameters.upload-docs >>
          steps:
            - run:
                name: Build documentation
                command: |
                  source env/bin/activate
                  cd docs && doxygen && make html O=-W
      - when:
          condition: << parameters.upload-docs >>
          steps:
            - add_ssh_keys:
                fingerprints:
                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
            - run:
                name: Upload documentation
                command: |
                  source env/bin/activate
                  git config user.email "mlx@group.apple.com"
                  git config user.name "CircleCI Docs"
                  git checkout gh-pages
                  git rebase main
                  cd docs
                  git rm -rf build/html
                  doxygen && make html O=-W
                  git add -f build/html
                  git commit -m "rebase"
                  git push -f origin gh-pages

  linux_build_and_test:
    docker:
      - image: cimg/python:3.9
    steps:
      - checkout
      - run:
          name: Run style checks
          command: |
            pip install pre-commit
            pre-commit run --all
            if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
      - run:
          name: Install dependencies
          command: |
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install numpy
            sudo apt-get update
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
      - run:
          name: Install Python package
          command: |
            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            python3 setup.py build_ext --inplace
            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            python3 setup.py develop
      - run:
          name: Generate package stubs
          command: |
            echo "stubs"
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            python3 -m unittest discover python/tests -v
            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
      - run:
          name: Build CPP only
          command: |
            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests

  mac_build_and_test:
    parameters:
      xcode_version:
        type: string
        default: "16.2.0"
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    resource_class: m2pro.medium
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            brew install python@3.9
            brew install openmpi
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install numpy
            pip install torch
            pip install tensorflow
            pip install unittest-xml-reporting
      - run:
          name: Install Python package
          command: |
            source env/bin/activate
            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
            pip install -e . -v
      - run:
          name: Generate package stubs
          command: |
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source env/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
      - run:
          name: Build example extension
          command: |
            source env/bin/activate
            cd examples/extensions
            pip install -r requirements.txt
            python setup.py build_ext -j8
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
            source env/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
      - run:
          name: Run CPP tests
          command: |
            DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
      - run:
          name: Build small binary
          command: |
            source env/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
              -DBUILD_SHARED_LIBS=ON \
              -DMLX_BUILD_CPU=OFF \
              -DMLX_BUILD_SAFETENSORS=OFF \
              -DMLX_BUILD_GGUF=OFF \
              -DMLX_METAL_JIT=ON
            make -j `sysctl -n hw.ncpu`
      - run:
          name: Run Python tests with JIT
          command: |
            source env/bin/activate
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
            pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
            METAL_DEBUG_ERROR_MODE=0 \
            python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

  build_release:
    parameters:
      python_version:
        type: string
        default: "3.9"
      xcode_version:
        type: string
        default: "16.2.0"
      build_env:
        type: string
        default: ""
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    resource_class: m2pro.medium
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            brew install python@<< parameters.python_version >>
            brew install openmpi
            python<< parameters.python_version >> -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install twine
            pip install build
      - run:
          name: Install Python package
          command: |
            source env/bin/activate
            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            pip install . -v
      - run:
          name: Generate package stubs
          command: |
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
            source env/bin/activate
            << parameters.build_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  source env/bin/activate
                  twine upload dist/*
            - store_artifacts:
                path: dist/

  build_linux_release:
    parameters:
      python_version:
        type: string
        default: "3.9"
      extra_env:
        type: string
        default: "DEV_RELEASE=1"
    docker:
      - image: ubuntu:20.04
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            PYTHON=python<< parameters.python_version >>
            apt-get update
            apt-get upgrade -y
            DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
            apt-get install -y apt-utils
            apt-get install -y software-properties-common
            add-apt-repository -y ppa:deadsnakes/ppa
            apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
            apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            apt-get install -y build-essential git
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            << parameters.extra_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            pip install . -v
            pip install typing_extensions
            python setup.py generate_stubs
            << parameters.extra_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            python -m build --wheel
            auditwheel show dist/*
            auditwheel repair dist/* --plat manylinux_2_31_x86_64
      - run:
          name: Upload package
          command: |
            source env/bin/activate
            twine upload wheelhouse/*
      - store_artifacts:
          path: wheelhouse/

workflows:
  build_and_test:
    when:
      and:
        - matches:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - mac_build_and_test:
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test
      - build_documentation

  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["PYPI_RELEASE=1"]
              xcode_version: ["16.2.0", "15.0.0"]
            exclude:
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.9"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.10"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.11"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.12"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.13"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
                build_env: "PYPI_RELEASE=1"
      - build_documentation:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          upload-docs: true

  prb:
    when:
      matches:
        pattern: "^pull/\\d+(/head)?$"
        value: << pipeline.git.branch >>
    jobs:
      - hold:
          type: approval
      - apple/authenticate:
          context: pr-approval
      - mac_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test:
          requires: [ hold ]

  nightly_build:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.nightly_build >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              xcode_version: ["16.2.0", "15.0.0"]
            exclude:
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.9"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.10"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.11"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.12"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.13"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.13"

  weekly_build:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.weekly_build >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["DEV_RELEASE=1"]
              xcode_version: ["16.2.0", "15.0.0"]
            exclude:
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.9"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.10"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.11"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.12"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.13"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
                build_env: "DEV_RELEASE=1"

  linux_test_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.linux_release >>
    jobs:
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              extra_env: ["PYPI_RELEASE=1"]
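The nightly, weekly, test, and Linux release workflows above only run when the matching pipeline parameter is set, so they are normally started through the CircleCI API rather than by a push. As a minimal sketch only (the endpoint and `Circle-Token` header are the standard CircleCI v2 pipeline-trigger API; the `CIRCLE_TOKEN` variable, the project slug, and the use of the `requests` package are assumptions, not taken from this diff), a scheduler could kick off the nightly build like this:

```python
# Hypothetical sketch: trigger the nightly_build workflow via the CircleCI API v2.
# Assumes an API token in CIRCLE_TOKEN and the gh/ml-explore/mlx project slug.
import os
import requests

resp = requests.post(
    "https://circleci.com/api/v2/project/gh/ml-explore/mlx/pipeline",
    headers={"Circle-Token": os.environ["CIRCLE_TOKEN"]},
    json={"branch": "main", "parameters": {"nightly_build": True}},
)
resp.raise_for_status()
print(resp.json()["id"])  # id of the newly created pipeline
```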
.github/actions/build-cuda-release/action.yml  (deleted)
@@ -1,15 +0,0 @@

name: 'Build CUDA wheel'
description: 'Build CUDA wheel'

runs:
  using: "composite"
  steps:
    - name: Build package
      shell: bash
      env:
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
      run: |
        pip install auditwheel build patchelf setuptools
        python setup.py clean --all
        MLX_BUILD_STAGE=2 python -m build -w
        bash python/scripts/repair_cuda.sh
.github/actions/build-docs/action.yml  (deleted)
@@ -1,38 +0,0 @@

name: 'Build Documentation'
description: 'Build documentation'

runs:
  using: "composite"
  steps:
    - name: Setup machine
      uses: ./.github/actions/setup-linux

    - name: Install dependencies
      shell: bash
      run: |
        sudo apt-get install -y doxygen
        source .venv/bin/activate
        pip install -r docs/requirements.txt
        pip install . -v

    - name: Build documentation
      shell: bash
      run: |
        source .venv/bin/activate
        cd docs
        doxygen
        make html O=-W

    - name: Create artifact tar
      shell: bash
      run: tar -cf artifact.tar -C docs --dereference build/html index.html

    # Do it manually because upload-pages-artifact requires gtar
    - name: Upload artifact
      id: upload-artifact
      uses: actions/upload-artifact@v5
      with:
        name: github-pages
        path: artifact.tar
        retention-days: 1
        if-no-files-found: error
.github/actions/build-linux-release/action.yml  (deleted)
@@ -1,40 +0,0 @@

name: 'Build Linux wheel'
description: 'Build Linux wheel'

inputs:
  build-backend:
    description: 'Build the backend mlx-cpu package'
    type: boolean
    required: false
    default: false
  arch:
    description: 'Platform architecture tag'
    required: true
    type: choice
    options:
      - x86_64
      - aarch64

runs:
  using: "composite"
  steps:
    - name: Generate package stubs
      shell: bash
      run: |
        pip install -e ".[dev]" -v
        pip install typing_extensions
        python setup.py generate_stubs
    - name: Build Python package
      shell: bash
      run: |
        pip install auditwheel patchelf build
        python setup.py clean --all
        MLX_BUILD_STAGE=1 python -m build -w
        bash python/scripts/repair_linux.sh ${{ inputs.arch }}
    - name: Build backend package
      if: ${{ inputs.build-backend }}
      shell: bash
      run: |
        python setup.py clean --all
        MLX_BUILD_STAGE=2 python -m build -w
        auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
.github/actions/build-linux/action.yml  (deleted)
@@ -1,41 +0,0 @@

name: 'Build and Test on Linux'

inputs:
  toolkit:
    description: 'The toolkit to build with'
    required: false
    default: 'cpu'

runs:
  using: "composite"
  steps:
    - name: Install Python package
      id: python_build
      shell: sh
      env:
        DEBUG: 1
        CMAKE_ARGS: >-
          -DCMAKE_COMPILE_WARNING_AS_ERROR=ON
          -DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
      run: |
        if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
          # There is no GPU in arm64 runner, use a common arch.
          CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
          # Can not build tests when the built executables can not run.
          CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
        fi
        pip install --no-build-isolation -e ".[dev]" -v
        # Pass the CMAKE_ARGS to following steps.
        echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT

    - name: Generate package stubs
      shell: sh
      run: |
        pip install typing_extensions
        python setup.py generate_stubs

    - name: Build CPP only
      shell: bash
      run: |
        cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
        cmake --build build -j $(nproc)
.github/actions/build-macos-release/action.yml  (deleted)
@@ -1,34 +0,0 @@

name: 'Build macOS release'
description: 'Build MLX releases macOS'

inputs:
  macos-target:
    description: 'macOS build target'
    required: false
    default: '15.0'
  build-backend:
    description: 'Build the backend mlx-metal package'
    type: boolean
    required: false
    default: false

runs:
  using: "composite"
  steps:
    - name: Build Python package
      shell: bash -l {0}
      env:
        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
      run: |
        pip install build
        python setup.py clean --all
        MLX_BUILD_STAGE=1 python -m build -w

    - name: Build backend package
      if: ${{ inputs.build-backend }}
      shell: bash -l {0}
      env:
        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
      run: |
        python setup.py clean --all
        MLX_BUILD_STAGE=2 python -m build -w
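The Linux, CUDA, and macOS release actions above drive a two-stage wheel build through the MLX_BUILD_STAGE environment variable: stage 1 produces the front-end mlx wheel, stage 2 the backend wheel (mlx_cpu, mlx_cuda, or mlx_metal). Purely as an illustrative sketch of that pattern, and not MLX's actual setup.py logic, a build script could branch on the variable like this:

```python
# Hypothetical sketch of the MLX_BUILD_STAGE pattern used by the actions above;
# the function name and package names chosen here are illustrative only.
import os

def wheel_target() -> str:
    stage = os.environ.get("MLX_BUILD_STAGE", "0")
    if stage == "1":
        return "mlx"        # front-end wheel
    if stage == "2":
        return "mlx-metal"  # backend wheel; the real name varies by platform
    return "mlx"            # default: single-package build

print(wheel_target())
```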
.github/actions/build-macos/action.yml  (deleted)
@@ -1,88 +0,0 @@

name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'

runs:
  using: "composite"
  steps:
    - name: Install dependencies
      env:
        DEBUG: 1
        CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
      shell: bash -l {0}
      run: |
        pip install --upgrade pip
        pip install cmake setuptools nanobind==2.4.0
        pip install -e . -v

    - name: Generate package stubs
      shell: bash -l {0}
      run: |
        pip install typing_extensions
        python setup.py generate_stubs

    - name: Install tests dependencies
      shell: bash -l {0}
      run: |
        pip install numpy torch tensorflow unittest-xml-reporting

    - name: Run Python tests
      shell: bash -l {0}
      env:
        LOW_MEMORY: 1
      run: |
        DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
        mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
        if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi

    - name: Build example extension
      shell: bash -l {0}
      run: |
        cd examples/extensions
        pip install -r requirements.txt
        python setup.py build_ext --inplace
        python test.py

    - name: Build CPP only
      shell: bash -l {0}
      run: |
        mkdir -p build
        cd build
        cmake ..
        make -j $(sysctl -n hw.ncpu)

    - name: Run CPP tests
      shell: bash -l {0}
      env:
        DEVICE: gpu
        METAL_DEVICE_WRAPPER_TYPE: 1
        METAL_DEBUG_ERROR_MODE: 0
      run: ./build/tests/tests

    - name: Build small binary with JIT
      shell: bash -l {0}
      run: |
        mkdir -p build
        cd build
        cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
          -DBUILD_SHARED_LIBS=ON \
          -DMLX_BUILD_CPU=OFF \
          -DMLX_BUILD_SAFETENSORS=OFF \
          -DMLX_BUILD_GGUF=OFF \
          -DMLX_METAL_JIT=ON
        make -j $(sysctl -n hw.ncpu)

    - name: Run Python tests with JIT
      shell: bash -l {0}
      env:
        LOW_MEMORY: 1
        DEVICE: gpu
        METAL_DEVICE_WRAPPER_TYPE: 1
        METAL_DEBUG_ERROR_MODE: 0
      run: |
        CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
          pip install -e . -v
        python -m xmlrunner discover \
          -v python/tests \
          -o test-results/gpu_jit
.github/actions/setup-linux/action.yml  (deleted)
@@ -1,86 +0,0 @@

name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'

inputs:
  toolkit:
    description: 'Which toolkit to install'
    required: false
    default: 'cpu'
  python-version:
    description: 'Version of python to set up'
    required: false
    default: '3.10'

runs:
  using: "composite"
  steps:
    - name: Use ccache
      uses: hendrikmuhs/ccache-action@v1.2
      with:
        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
        max-size: 1GB

    - name: Install common dependencies
      shell: bash
      run: |
        sudo apt-get update
        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip

    - uses: actions/setup-python@v6
      with:
        python-version: ${{ inputs.python-version }}

    - name: Setup Python venv
      shell: bash
      run: |
        python -m venv .venv
        source .venv/bin/activate
        pip install setuptools cmake nanobind==2.4.0
        echo PATH=$PATH >> $GITHUB_ENV
        # Make cmake search .venv for nanobind
        echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV

    - name: Install MPI
      shell: bash
      run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev

    - name: Install CUDA toolkit
      if: ${{ startsWith(inputs.toolkit, 'cuda') }}
      shell: bash
      env:
        # Note: the CI machine does not meet CUDA 13's driver requirement.
        # Compatibility matrix:
        # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
        PACKAGES: |
          {
            "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
            "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
            "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
          }
      run: |
        # The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
        # Jetson specific. SBSA means Arm Server Base System Architecture.
        ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
        sudo dpkg -i cuda-keyring_1.1-1_all.deb
        sudo apt-get update
        sudo apt-get install -y \
          libnccl2 libnccl-dev \
          ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
        echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH

    - name: CUDA packages and driver report
      if: ${{ startsWith(inputs.toolkit, 'cuda') }}
      shell: bash
      run: |
        sudo apt-get install -y ubuntu-drivers-common dkms
        echo "NVIDIA Driver Packages Available:"
        sudo ubuntu-drivers list --gpgpu
        echo "NVIDIA Driver Version:"
        cat /proc/driver/nvidia/version || echo "nvidia driver not found"
        echo "Installed NVIDIA and CUDA packages:"
        dpkg -l | egrep "cuda|nvidia" -i
        echo "DKMS Status:"
        dkms status || echo "dkms not found"
        echo "NVIDIA-SMI Status:"
        nvidia-smi || echo "nvidia-smi not found"
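The CUDA step above picks its apt packages by indexing a JSON map with `fromJson(env.PACKAGES)[inputs.toolkit]`. A minimal stand-alone sketch of the same lookup (the mapping values are copied from the action above; the function and its behavior on unknown keys are illustrative, not part of the action):

```python
# Illustrative re-implementation of the fromJson(env.PACKAGES)[inputs.toolkit]
# lookup performed by the setup-linux action above.
import json

PACKAGES = json.loads("""
{
  "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
  "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
  "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
""")

def cuda_packages(toolkit: str) -> str:
    # Raises KeyError for toolkits the map does not know about; the GitHub
    # expression would instead evaluate to an empty string.
    return PACKAGES[toolkit]

print(cuda_packages("cuda-12.9"))  # -> "libcudnn9-dev-cuda-12 cuda-toolkit-12-9"
```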
.github/actions/setup-macos/action.yml  (deleted)
@@ -1,24 +0,0 @@

name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'

inputs:
  python-version:
    description: 'Python version to use'
    required: false
    default: '3.10'

runs:
  using: "composite"
  steps:
    - name: Install Homebrew packages
      shell: sh
      run: /opt/homebrew/bin/brew install openmpi

    - name: Verify MetalToolchain installed
      shell: bash
      run: xcodebuild -showComponent MetalToolchain

    - uses: conda-incubator/setup-miniconda@v3
      with:
        miniconda-version: "latest"
        python-version: ${{ inputs.python-version }}
.github/actions/test-linux/action.yml  (deleted)
@@ -1,69 +0,0 @@

name: 'Run Linux tests'

inputs:
  has-gpu:
    description: 'Run GPU tests'
    required: false
    default: false

runs:
  using: "composite"
  steps:
    - name: Run MPI tests
      shell: bash
      run: |
        echo "::group::MPI tests"
        mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
        echo "::endgroup::"

    - name: Run distributed tests
      if: ${{ inputs.has-gpu == 'false' }}
      shell: bash
      run: |
        echo "::group::Distributed tests"
        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
        if grep -Fq '[WARN]' stderr.log ; then
          grep -F '[WARN]' stderr.log
          echo "Distributed ring test failed";
          exit 1;
        fi
        echo "::endgroup::"

    - name: Run Python tests - CPU
      if: ${{ inputs.has-gpu == 'false' }}
      shell: bash
      env:
        DEVICE: cpu
      run: |
        echo "::group::Python tests - CPU"
        python -m unittest discover python/tests -v
        echo "::endgroup::"

    - name: Run Python tests - GPU
      if: ${{ inputs.has-gpu == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
      run: |
        echo "::group::Python tests - GPU"
        python -m tests discover python/tests -v
        echo "::endgroup::"

    - name: Run CPP tests - CPU
      shell: bash
      env:
        DEVICE: cpu
      run: |
        echo "::group::CPP tests - CPU"
        ./build/tests/tests
        echo "::endgroup::"

    - name: Run CPP tests - GPU
      if: ${{ inputs.has-gpu == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
      run: |
        echo "::group::CPP tests - GPU"
        ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
        echo "::endgroup::"
.github/dependabot.yml  (deleted)
@@ -1,6 +0,0 @@

version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "weekly"
.github/scripts/setup+build-cpp-linux-fedora-container.sh  (deleted; this is the script invoked by the Fedora job in build_and_test.yml below)
@@ -1,27 +0,0 @@

#!/bin/bash
set -ex

# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
  blas-devel \
  lapack-devel \
  openblas-devel \
  make \
  cmake \
  clang \
  git
dnf clean all

# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++

mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd
.github/workflows/build_and_test.yml  (deleted)
@@ -1,108 +0,0 @@

name: Build and Test

on:
  pull_request:
  push:
    branches:
      - main
      # For testing CI without starting a pull request:
      - test/*

permissions:
  contents: read

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

jobs:
  check_lint:
    name: Check Lint
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - uses: pre-commit/action@v3.0.1

  linux_build_and_test:
    name: Linux (cpu, ${{ matrix.arch }})
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        arch: ['x86_64', 'aarch64']
    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux
      - uses: ./.github/actions/test-linux

  cuda_build_and_test:
    name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
    if: github.repository == 'ml-explore/mlx'
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        arch: ['x86_64', 'aarch64']
        toolkit: ['cuda-12.6', 'cuda-12.9']
    runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/build-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/test-linux
        if: matrix.arch == 'x86_64'
        with:
          has-gpu: true

  mac_build_and_test:
    name: macOS (${{ matrix.macos-target }})
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        macos-target: ["14.0", "15.0"]
    runs-on: [self-hosted, macos]
    env:
      MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
    needs: check_lint
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
      - uses: ./.github/actions/build-macos

  build_documentation:
    name: Build Documentation
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22.04
    needs: check_lint
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  linux_fedora_build_cpp:
    name: Linux Fedora (${{ matrix.arch }})
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        include:
          - host: ubuntu-22.04
            arch: x86_64
          - host: ubuntu-22.04-arm
            arch: aarch64

    runs-on: ${{ matrix.host }}
    container:
      image: fedora:42
    steps:
      - name: Checkout code
        uses: actions/checkout@v6

      - name: CPP Build Test - No Release
        run: |
          bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
.github/workflows/documentation.yml  (deleted)
@@ -1,28 +0,0 @@

name: Documentation

on:
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  deploy:
    needs: build
    permissions:
      pages: write
      id-token: write
    runs-on: ubuntu-latest
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4
.github/workflows/nightly.yml  (deleted)
@@ -1,96 +0,0 @@

name: Nightly Build

on:
  schedule:
    - cron: 33 6 * * 1-5
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build_linux_release:
    strategy:
      fail-fast: false
      matrix:
        python_version: ["3.10", "3.14"]
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux-release
        with:
          build-backend: ${{ matrix.python-version == '3.10' }}
          arch: "x86_64"
      - name: Upload mlx artifacts
        uses: actions/upload-artifact@v5
        with:
          name: linux-wheels-${{ matrix.python_version }}
          path: wheelhouse/mlx-*.whl
          retention-days: 7
      - name: Upload mlx-cpu artifacts
        if: matrix.python_version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          name: mlx-cpu
          path: wheelhouse/mlx_cpu-*.whl
          retention-days: 7

  build_linux_with_tests:
    strategy:
      fail-fast: false
      matrix:
        python_version: ["3.11", "3.12", "3.13", "3.14"]
        runner:
          - ubuntu-22.04
          - ubuntu-22.04-arm
    runs-on: ${{ matrix.runner }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
      - uses: ./.github/actions/build-linux
      - uses: ./.github/actions/test-linux

  build_mac_release:
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        python-version: ["3.10", "3.13"]
    runs-on: [self-hosted, macos]
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}
      - uses: ./.github/actions/build-macos
      - name: Build macOS 15 package
        uses: ./.github/actions/build-macos-release
        with:
          macos-target: 15.0
          build-backend: ${{ matrix.python-version == '3.10' }}
      - name: Build macOS 14 package
        uses: ./.github/actions/build-macos-release
        with:
          macos-target: 14.0
          build-backend: ${{ matrix.python-version == '3.10' }}

  build_cuda_release:
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22-large
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: 'cuda-12.9'
      - name: Build Python package
        uses: ./.github/actions/build-cuda-release
        with:
          toolkit: 'cuda-12.9'
      - name: Upload artifacts
        uses: actions/upload-artifact@v5
        with:
          name: mlx-cuda
          path: wheelhouse/mlx_cuda-*.whl
          retention-days: 7
.github/workflows/pull_request.yml  (new file)
@@ -0,0 +1,20 @@

on:
  pull_request:
    branches:
      - main

jobs:
  check_lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: 3.8
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pre-commit black isort clang-format
      - name: Run lint
        run: |
          pre-commit run --all-files
.github/workflows/release.yml  (deleted)
@@ -1,238 +0,0 @@

name: PyPI Release

on:
  push:
    tags:
      - 'v*'
  workflow_dispatch:
    inputs:
      dev_release:
        description: "Do a dev release or regular release"
        required: true
        default: "false"

permissions:
  contents: read

jobs:
  setup:
    runs-on: ubuntu-latest
    steps:
      - name: Set publishing variables
        run: echo "Publishing setup complete"

  build_documentation:
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  deploy_documentation:
    needs: build_documentation
    permissions:
      pages: write
      id-token: write
    runs-on: ubuntu-latest
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4

  build_linux_release:
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
        arch: ['x86_64', 'aarch64']
    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
    env:
      PYPI_RELEASE: 1
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
      - uses: ./.github/actions/build-linux-release
        with:
          build-backend: ${{ matrix.python-version == '3.10' }}
          arch: ${{ matrix.arch }}
      - name: Upload MLX artifacts
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
          path: wheelhouse/mlx-*.whl
      - name: Upload CPU artifacts
        if: matrix.python_version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: mlx-cpu-${{ matrix.arch }}
          path: wheelhouse/mlx_cpu-*.whl

  build_mac_release:
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
    runs-on: [self-hosted, macos]
    env:
      PYPI_RELEASE: 1
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}

    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        shell: bash -l {0}
        run: |
          pip install --upgrade pip
          pip install cmake setuptools nanobind==2.4.0
          pip install -e . -v
      - name: Generate package stubs
        shell: bash -l {0}
        run: |
          pip install typing_extensions
          python setup.py generate_stubs
      - name: Build macOS 14 package
        uses: ./.github/actions/build-macos-release
        with:
          macos-target: 14.0
          build-backend: ${{ matrix.python-version == '3.10' }}
      - name: Build macOS 15 package
        uses: ./.github/actions/build-macos-release
        with:
          macos-target: 15.0
          build-backend: ${{ matrix.python-version == '3.10' }}
      - name: Upload MLX artifacts
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: mac-wheels-${{ matrix.python-version }}
          path: dist/mlx-*.whl
      - name: Upload Metal artifacts
        if: matrix.python-version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: mlx-metal
          path: dist/mlx_metal-*.whl

  build_cuda_release:
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22-large
    env:
      PYPI_RELEASE: 1
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: 'cuda-12.9'
      - name: Build Python package
        uses: ./.github/actions/build-cuda-release
      - name: Upload artifacts
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: mlx-cuda
          path: wheelhouse/mlx_cuda-*.whl

  pypi-publish:
    name: Upload release to PyPI
    runs-on: ubuntu-latest
    needs: [setup, build_linux_release, build_mac_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx
    steps:
      - uses: actions/download-artifact@v6
        with:
          pattern: linux-wheels-*
          merge-multiple: true
          path: dist
      - uses: actions/download-artifact@v6
        with:
          pattern: mac-wheels-*
          merge-multiple: true
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://upload.pypi.org/legacy/

  pypi-publish-cuda:
    name: Upload CUDA release to PyPI
    runs-on: ubuntu-latest
    needs: [setup, build_cuda_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-cuda
    steps:
      - uses: actions/download-artifact@v6
        with:
          name: mlx-cuda
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://upload.pypi.org/legacy/

  pypi-publish-cpu:
    name: Upload CPU release to PyPI
    runs-on: ubuntu-latest
    needs: [setup, build_linux_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-cpu
    steps:
      - uses: actions/download-artifact@v6
        with:
          pattern: mlx-cpu-*
          merge-multiple: true
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://upload.pypi.org/legacy/

  pypi-publish-metal:
    name: Upload Metal release to PyPI
    runs-on: ubuntu-latest
    needs: [setup, build_mac_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-metal
    steps:
      - uses: actions/download-artifact@v6
        with:
          name: mlx-metal
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://upload.pypi.org/legacy/
.gitignore  (1 line removed)
@@ -36,7 +36,6 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
-uv.lock
 
 # vim
 *.swp
.pre-commit-config.yaml
@@ -1,10 +1,4 @@
 repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v6.0.0
-    hooks:
-      - id: check-yaml
-      # - id: end-of-file-fixer
-      # - id: trailing-whitespace
   - repo: https://github.com/pre-commit/mirrors-clang-format
     rev: v19.1.7
     hooks:
ACKNOWLEDGMENTS.md
@@ -19,17 +19,11 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>
 
-# Organizations
-
-MLX has received contributions from the following companies:
-- NVIDIA Corporation & Affiliates
-
 # Third-Party Software
 
 MLX leverages several third-party software, listed here together with
|||||||
@@ -26,7 +26,6 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(CMAKE_INSTALL_MESSAGE NEVER)
-set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

 # ----------------------------- Configuration -----------------------------
 option(MLX_BUILD_TESTS "Build tests for mlx" ON)

@@ -35,16 +34,13 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
 option(MLX_BUILD_CPU "Build cpu backend" ON)
-option(MLX_BUILD_CUDA "Build cuda backend" OFF)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
 option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
-option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
-option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)

 # --------------------- Processor tests -------------------------
 message(

@@ -67,18 +63,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
     message(WARNING "Building for x86_64 arch is not officially supported.")
   endif()
 endif()

 else()
   set(MLX_BUILD_METAL OFF)
-endif()
-
-if(MLX_USE_CCACHE)
-  find_program(CCACHE_PROGRAM ccache)
-  if(CCACHE_PROGRAM)
-    message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
-    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
-    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
-    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
-  endif()
+  message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
 endif()

 # ----------------------------- Lib -----------------------------

@@ -89,26 +77,18 @@ cmake_policy(SET CMP0135 NEW)

 add_library(mlx)

-# Supress warnings: note: parameter passing for argument of type
-# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
-# 10.1
-target_compile_options(mlx PRIVATE -Wno-psabi)
-
-if(MLX_BUILD_CUDA)
-  enable_language(CUDA)
+if(MLX_BUILD_METAL)
+  set(METAL_LIB "-framework Metal")
+  set(FOUNDATION_LIB "-framework Foundation")
+  set(QUARTZ_LIB "-framework QuartzCore")
 endif()

-if(MLX_BUILD_METAL)
-  find_library(METAL_LIB Metal)
-  find_library(FOUNDATION_LIB Foundation)
-  find_library(QUARTZ_LIB QuartzCore)
-  if(METAL_LIB)
-    message(STATUS "Metal found ${METAL_LIB}")
-  else()
-    message(
-      FATAL_ERROR
-        "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
-  endif()
+if(MLX_BUILD_METAL AND NOT METAL_LIB)
+  message(STATUS "Metal not found. Unable to build GPU")
+  set(MLX_BUILD_METAL OFF)
+  set(MLX_METAL_DEBUG OFF)
+elseif(MLX_BUILD_METAL)
+  message(STATUS "Building METAL sources")

 if(MLX_METAL_DEBUG)
   add_compile_definitions(MLX_METAL_DEBUG)

@@ -117,12 +97,7 @@ if(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_SDK_VERSION
-    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
-  execute_process(
-    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
-    OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
-    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)

   if(${MACOS_SDK_VERSION} LESS 14.0)
     message(

@@ -132,12 +107,9 @@ if(MLX_BUILD_METAL)
   message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")

   set(METAL_CPP_URL
-      https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
+      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)

   if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
-    if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
-      message(FATAL_ERROR "MLX requires macOS >= 14.0")
-    endif()
     set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
   endif()
   execute_process(

@@ -146,6 +118,7 @@ if(MLX_BUILD_METAL)
     "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
     OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
   FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})

   FetchContent_MakeAvailable(metal_cpp)
   target_include_directories(
     mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>

@@ -153,12 +126,6 @@ if(MLX_BUILD_METAL)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()

-if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
-  # With newer clang/gcc versions following libs are implicitly linked, but when
-  # building on old distributions they need to be explicitly listed.
-  target_link_libraries(mlx PRIVATE dl pthread)
-endif()
-
 if(WIN32)
   if(MSVC)
     # GGUF does not build with MSVC.

@@ -186,7 +153,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
   else()
-    message(STATUS "Accelerate not found, using default backend.")
+    message(STATUS "Accelerate or arm neon not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
   endif()

@@ -259,19 +226,12 @@ target_include_directories(
   mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
              $<INSTALL_INTERFACE:include>)

-# Do not add mlx_EXPORTS define for shared library.
-set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
-
-if(USE_SYSTEM_FMT)
-  find_package(fmt REQUIRED)
-else()
-  FetchContent_Declare(
+FetchContent_Declare(
   fmt
   GIT_REPOSITORY https://github.com/fmtlib/fmt.git
   GIT_TAG 10.2.1
   EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(fmt)
-endif()
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)

 if(MLX_BUILD_PYTHON_BINDINGS)

@@ -1,6 +1,4 @@
 include CMakeLists.txt
-include mlx.pc.in
 recursive-include mlx/ *
-include cmake/*
 include python/src/*
 include python/mlx/py.typed # support type hinting as in PEP-561
README.md

@@ -11,28 +11,28 @@ brought to you by Apple machine learning research.

 Some key features of MLX include:

 - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
   also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
   [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
   the Python API. MLX has higher-level packages like `mlx.nn` and
   `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
   more complex models.

 - **Composable function transformations**: MLX supports composable function
   transformations for automatic differentiation, automatic vectorization,
   and computation graph optimization.

 - **Lazy computation**: Computations in MLX are lazy. Arrays are only
   materialized when needed.

 - **Dynamic graph construction**: Computation graphs in MLX are constructed
   dynamically. Changing the shapes of function arguments does not trigger
   slow compilations, and debugging is simple and intuitive.

 - **Multi-device**: Operations can run on any of the supported devices
   (currently the CPU and the GPU).

 - **Unified memory**: A notable difference from MLX and other frameworks
   is the *unified memory model*. Arrays in MLX live in shared memory.
   Operations on MLX arrays can be performed on any of the supported
   device types without transferring data.
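The lazy-computation and unified-memory points above are easiest to see in a couple of lines of Python. A minimal sketch (not part of the diff; it only uses standard `mlx.core` calls):

```python
import mlx.core as mx

a = mx.random.uniform(shape=(4, 4))
b = a @ a + 1.0             # builds the compute graph; nothing runs yet
mx.eval(b)                  # materializes `b` on the default device
print(mx.default_device())  # the same array is usable from CPU and GPU without copies
```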
@@ -68,23 +68,18 @@ in the documentation.

 ## Installation

-MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
-macOS, run:
+MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:

-```bash
+**With `pip`**:
+
+```
 pip install mlx
 ```

-To install the CUDA backend on Linux, run:
+**With `conda`**:

-```bash
-pip install mlx[cuda]
-```
-
-To install a CPU-only Linux package, run:
-
-```bash
-pip install mlx[cpu]
+```
+conda install -c conda-forge mlx
 ```
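Either install can be sanity-checked from Python; a small hedged example (device names depend on the build you installed):

```python
import mlx.core as mx

print(mx.__version__)       # the installed MLX version
print(mx.default_device())  # gpu on Apple silicon / CUDA builds, cpu on CPU-only builds
```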
 Checkout the

@@ -110,7 +105,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
 MLX useful in your research and wish to cite it, please use the following
 BibTex entry:

-```text
+```
 @software{mlx2023,
   author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
   title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
@@ -1,6 +1,5 @@
 // Copyright © 2023 Apple Inc.

-#include <cstring>
 #include <iostream>
 #include <sstream>

@@ -75,7 +74,7 @@ void time_irregular_binary_ops_3D() {

 void time_irregular_binary_ops_4D() {
   auto device = mx::default_device();
-  mx::Shape shape = {8, 8, 512, 512};
+  std::vector<int> shape = {8, 8, 512, 512};
   auto a = mx::random::uniform(shape);
   auto b = mx::random::uniform(shape);

@@ -115,7 +114,7 @@ void time_irregular_binary_ops_4D() {

 void time_irregular_reshape() {
   auto device = mx::default_device();
-  mx::Shape shape;
+  std::vector<int> shape;
   auto reshape_fn = [&shape, device](const mx::array& a) {
     return mx::reshape(a, shape, device);
   };

@@ -170,7 +169,7 @@ void time_irregular_astype_1D() {
 void time_irregular_astype_2D() {
   auto device = mx::default_device();
   int size = 2048;
-  mx::Shape shape = {size, size};
+  std::vector<int> shape = {size, size};

   auto a = mx::random::uniform(shape);
   TIMEM("2D regular", mx::astype, a, mx::int32, device);

@@ -192,22 +192,6 @@ void time_reductions() {

   auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
   TIME(argmin_along_1);
-
-  auto indices = mx::array({1});
-  auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
-  std::vector<int> axes{0};
-  auto b = scatter(a, {indices}, updates, axes);
-  mx::eval(b);
-
-  auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
-  TIME(max_along_0);
-  auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
-  TIME(max_along_1);
-
-  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
-  TIME(min_along_0);
-  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
-  TIME(min_along_1);
 }

 void time_gather_scatter() {
@@ -142,7 +142,9 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
     t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)

     c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
-    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
+    c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
+        np.float32
+    )

     atol = 1e-5 if np_dtype == np.float32 else 1e-4

@@ -155,13 +157,13 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):


 def get_gflop_count(B, M, N, K):
-    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
+    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1000.0**3)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run gemm benchmarks")

-    dtypes = ("float32", "float16", "complex64")
+    dtypes = ("float32", "float16")
     transposes = ("nn", "nt", "tn")
     shapes = (
         (16, 234, 768, 3072),

@@ -173,6 +175,8 @@ if __name__ == "__main__":
         (1, 4096, 4096, 4096),
     )

+    print(f" B, M, N, K, dtype, t, gflops_pt, gflops_mx, diff%")
+
     for dtype in dtypes:
         for transpose in transposes:
             for B, M, N, K in shapes:

@@ -185,7 +189,7 @@ if __name__ == "__main__":
                 diff = gflops_mx / gflops_pt - 1.0

                 print(
-                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
+                    f"{B:3d}, {M:4d}, {N:4d}, {K:5d}, {dtype}, {transpose}, {gflops_pt:8.2f}, {gflops_mx:8.2f}, {100. * diff:+5.2f}%"
                 )
                 if gflops_pt >= 2.0 * gflops_mx:
                     print("ATTENTION ^^^^^^^")
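For reference, the `1024.0**3` to `1000.0**3` change above switches the reported figure from a 2^30-based (GiFLOP/s) to a 10^9-based (GFLOP/s) convention. A hedged one-liner showing the latter convention:

```python
def gflops(b, m, n, k, seconds):
    # 2*M*N*K multiply-adds per matmul, batched over b, reported per 10^9 FLOP
    return (2.0 * b * m * n * k) / seconds / 1000.0**3
```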
@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.

+import argparse
 import os
 import subprocess
 import time

@@ -195,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):


 for transpose in (False, True):
-    for dtype in ("float32", "float16", "complex64"):
+    for dtype in ("float32", "float16"):
         fig, axs = plt.subplots(
             len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
         )

@@ -214,7 +215,7 @@ for transpose in (False, True):
         fig.suptitle(f"{device_name}: {dtype} {op_name}")
         fig.savefig(
             os.path.join(
-                results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
+                results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
             )
         )
         plt.close(fig)
@@ -5,7 +5,6 @@ import os
 import time

 import torch
-import torch.cuda
 import torch.mps

@@ -45,10 +44,8 @@ def bench(f, *args):


 def sync_if_needed(x):
-    if x.device == torch.device("mps"):
+    if x.device != torch.device("cpu"):
         torch.mps.synchronize()
-    elif x.device == torch.device("cuda"):
-        torch.cuda.synchronize()


 @torch.no_grad()

@@ -102,14 +99,6 @@ def reduction(op, axis, x):
     sync_if_needed(x)


-@torch.no_grad()
-def sum_and_add(axis, x, y):
-    z = x.sum(axis=axis, keepdims=True)
-    for i in range(50):
-        z = (z + y).sum(axis=axis, keepdims=True)
-    sync_if_needed(x)
-
-
 @torch.no_grad()
 def softmax(axis, x):
     ys = []

@@ -351,11 +340,7 @@ if __name__ == "__main__":
        args.axis.pop(0)

    torch.set_num_threads(1)
-    device = "mps"
-    if torch.cuda.is_available():
-        device = "cuda"
-    if args.cpu:
-        device = "cpu"
+    device = "cpu" if args.cpu else "mps"

    types = args.dtype
    if not types:

@@ -475,8 +460,5 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))

-    elif args.benchmark == "sum_and_add":
-        print(bench(sum_and_add, axis, *xs))
-
     else:
         raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
@@ -1,107 +0,0 @@
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    dtype = "float32"
    shapes = (
        (4, 32, 32, 21, 3, 3, 128),
        (4, 32, 32, 21, 3, 3, 37),
        (4, 32, 32, 370, 3, 3, 370),
        (4, 32, 32, 370, 7, 7, 128),
        (2, 320, 640, 21, 7, 7, 21),
    )
    for N, H, W, C, kh, kw, O in shapes:
        time_mlx, time_torch = bench_shape(
            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
        )
        diff = time_torch / time_mlx - 1.0

        print(
            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
        )
        if time_mlx >= 2.0 * time_torch:
            print("ATTENTION ^^^^^^^")
@@ -1,4 +1,4 @@
-# Copyright © 2025 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import mlx.core as mx
 from time_utils import time_fn

@@ -1,84 +0,0 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate(
            [
                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
                for i, j in enumerate(idx.tolist())
            ],
            axis=0,
        )
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_qmm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = mx.quantized_matmul(x, *w1, transpose=True)
        x = mx.quantized_matmul(x, *w2, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_qmm()

@@ -1,7 +1,5 @@
 # Copyright © 2023-2024 Apple Inc.

-from functools import partial
-
 import mlx.core as mx
 import mlx.nn as nn
 from time_utils import time_fn

@@ -20,63 +18,51 @@ def layer_norm(x, w, b, eps):
     return y


-def time_layer_norm(N, dt):
-    L = 1024
+def time_layer_norm():
     f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
     f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0, 1, 2))
     g2 = mx.grad(f2, argnums=(0, 1, 2))

-    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
-    w = mx.random.uniform(shape=(N,)).astype(dt)
-    b = mx.random.uniform(shape=(N,)).astype(dt)
-    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
     mx.eval(x, w, b, y)

-    def layer_norm_loop(f, x, w, b):
-        for _ in range(32):
-            x = f(x, w, b)
-        return x
-
-    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
-    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
-
-    def layer_norm_grad_loop(g, x, w, b):
+    def layer_norm_loop(g, x, w, b):
         gx, gw, gb = x, w, b
         for _ in range(32):
             gx, gw, gb = g(gx, gw, gb, y)
         return gx, gw, gb

-    time_fn(layer_norm_grad_loop, g1, x, w, b)
-    time_fn(layer_norm_grad_loop, g2, x, w, b)
-    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
-    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
+    time_fn(layer_norm_loop, g1, x, w, b)
+    time_fn(layer_norm_loop, g2, x, w, b)
+    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
+    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)

     f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
     f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0,))
     g2 = mx.grad(f2, argnums=(0,))

-    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
-    w = mx.random.uniform(shape=(N,)).astype(dt)
-    b = mx.random.uniform(shape=(N,)).astype(dt)
-    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
     mx.eval(x, w, b, y)

-    def layer_norm_grad_x_loop(g, x):
+    def layer_norm_loop(g, x):
         gx = x
         for _ in range(32):
             gx = g(gx, y)
         return gx

-    time_fn(layer_norm_grad_x_loop, g1, x)
-    time_fn(layer_norm_grad_x_loop, g2, x)
-    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
-    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
+    time_fn(layer_norm_loop, g1, x)
+    time_fn(layer_norm_loop, g2, x)
+    time_fn(layer_norm_loop, mx.compile(g1), x)
+    time_fn(layer_norm_loop, mx.compile(g2), x)


 if __name__ == "__main__":
-    for dt in [mx.float32, mx.float16, mx.bfloat16]:
-        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
-            print(dt, n)
-            time_layer_norm(n, dt)
+    time_layer_norm()

@@ -1,212 +0,0 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial

import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter

RESULTS_DIR = "./results"


if not os.path.isdir(RESULTS_DIR):
    os.mkdir(RESULTS_DIR)

DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")

TORCH_DEVICE = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)


N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20

VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")


def _power_of_two_formatter(value, _position):
    if value <= 0:
        return ""
    exponent = int(round(math.log2(value)))
    if abs(value - (1 << exponent)) / value > 1e-6:
        return f"{value:g}"
    return f"$2^{{{exponent}}}$"


def torch_sync():
    if TORCH_DEVICE.type == "cuda":
        torch.cuda.synchronize()
    elif TORCH_DEVICE.type == "mps":
        torch.mps.synchronize()


def masked_scatter_mlx(self_arr, mask_arr, src_arr):
    outs = []
    for _ in range(N_ITER_FUNC):
        out = copy(self_arr)
        out[mask_arr] = src_arr
        outs.append(out)
    mx.eval(outs)
    return outs


@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
    outs = []
    for _ in range(N_ITER_FUNC):
        out = self_tensor.clone()
        out.masked_scatter_(mask_tensor, src_tensor)
        outs.append(out)
    torch_sync()
    return outs


def measure(fn):
    for _ in range(N_WARMUP):
        fn()
    start = time.perf_counter_ns()
    for _ in range(N_ITER_BENCH):
        fn()
    end = time.perf_counter_ns()
    return (end - start) * 1e-9


def bytes_touched(length, true_count, item_size):
    mask_bytes = length
    self_bytes = length * item_size * 2  # read + write
    src_bytes = true_count * item_size
    return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH


def build_case(length, density, np_dtype, torch_dtype):
    true_count = max(1, int(round(length * density)))

    rng = np.random.default_rng()
    self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
    mask_np = np.zeros(length, dtype=bool)
    mask_np[:true_count] = True
    rng.shuffle(mask_np)
    src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)

    self_mlx = mx.array(self_np)
    mask_mlx = mx.array(mask_np)
    src_mlx = mx.array(src_np)

    self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
    mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
    src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)

    # Correctness check once per configuration
    mx_out = mx.array(self_np)
    mx_out[mask_mlx] = src_mlx
    mx.eval(mx_out)
    torch_out = self_torch.clone()
    torch_out.masked_scatter_(mask_torch, src_torch)

    atol = 5e-3 if np_dtype == np.float16 else 1e-5
    if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
        raise AssertionError("masked_scatter results diverged between MLX and Torch")

    return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)


def bench_case(length, density, dtype):
    np_dtype = getattr(np, dtype)
    torch_dtype = getattr(torch, dtype)
    (
        self_mlx,
        mask_mlx,
        src_mlx,
        self_torch,
        mask_torch,
        src_torch,
        true_count,
    ) = build_case(length, density, np_dtype, torch_dtype)

    time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
    time_torch = measure(
        partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
    )

    total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
    bytes_per_gb = float(1024**3)
    mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
    torch_gbps = (total_bytes / bytes_per_gb) / time_torch

    return time_mlx, time_torch, mlx_gbps, torch_gbps


def plot_density(ax_perf, ax_speedup, density, dtype):
    mlx_gbps = []
    torch_gbps = []
    mlx_times = []
    torch_times = []

    for length in VECTOR_LENGTHS:
        t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
        mlx_gbps.append(gbps_mlx)
        torch_gbps.append(gbps_torch)
        mlx_times.append(t_mlx)
        torch_times.append(t_torch)

    ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
    ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
    ax_perf.set_xscale("log", base=2)
    ax_perf.set_xticks(VECTOR_LENGTHS)
    formatter = FuncFormatter(_power_of_two_formatter)
    ax_perf.xaxis.set_major_formatter(formatter)
    ax_perf.set_title(f"density={density:.2f}")
    ax_perf.set_ylabel("GB/s")
    ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
    ax_perf.legend()

    speedup = np.array(torch_times) / np.array(mlx_times)
    ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
    ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
    ax_speedup.set_xscale("log", base=2)
    ax_speedup.set_xticks(VECTOR_LENGTHS)
    ax_speedup.xaxis.set_major_formatter(formatter)
    ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
    ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)


def main():
    for dtype in D_TYPES:
        fig, axs = plt.subplots(
            len(MASK_DENSITIES),
            2,
            figsize=(10, 12),
            layout="constrained",
            sharex=True,
        )

        for i, density in enumerate(MASK_DENSITIES):
            plot_density(axs[i][0], axs[i][1], density, dtype)
            axs[i][0].set_xlabel("vector length")
            axs[i][1].set_xlabel("vector length")

        fig.suptitle(
            f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
        )
        output_path = os.path.join(
            RESULTS_DIR,
            f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
        )
        fig.savefig(output_path)
        plt.close(fig)


if __name__ == "__main__":
    main()
@@ -51,20 +51,6 @@ def time_maximum():
     time_fn(mx.maximum, a, b)


-def time_max():
-    a = mx.random.uniform(shape=(32, 1024, 1024))
-    a[1, 1] = mx.nan
-    mx.eval(a)
-    time_fn(mx.max, a, 0)
-
-
-def time_min():
-    a = mx.random.uniform(shape=(32, 1024, 1024))
-    a[1, 1] = mx.nan
-    mx.eval(a)
-    time_fn(mx.min, a, 0)
-
-
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)

@@ -122,8 +108,6 @@ if __name__ == "__main__":

     time_add()
     time_matmul()
-    time_min()
-    time_max()
     time_maximum()
     time_exp()
     time_negative()

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.

set(NCCL_ROOT_DIR
    $ENV{NCCL_ROOT_DIR}
    CACHE PATH "Folder contains NVIDIA NCCL")

find_path(
  NCCL_INCLUDE_DIRS
  NAMES nccl.h
  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
        ${CUDA_TOOLKIT_ROOT_DIR}/include)

if($ENV{USE_STATIC_NCCL})
  message(
    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
  set(NCCL_LIBNAME "libnccl_static.a")
else()
  set(NCCL_LIBNAME "nccl")
endif()

find_library(
  NCCL_LIBRARIES
  NAMES ${NCCL_LIBNAME}
  HINTS ${NCCL_LIB_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/lib
        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
        ${NCCL_ROOT_DIR}/lib64
        ${CUDA_TOOLKIT_ROOT_DIR}/lib
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
                                  NCCL_LIBRARIES)

if(NCCL_FOUND)
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  file(
    STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
                         NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

@@ -1,3 +0,0 @@
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.

@@ -11,14 +11,13 @@ include(CMakeParseArguments)
 # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
 # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
 # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
-# files (like headers) DEBUG: Boolean, if true, enables debug compile options
-# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
+# files (like headers)
 #
 # clang format on

 macro(mlx_build_metallib)
   # Parse args
-  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
+  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
   cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

@@ -27,10 +26,6 @@ macro(mlx_build_metallib)

   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
-  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
-    set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-        -frecord-sources)
-  endif()

   # Prepare metallib build command
   add_custom_command(

@@ -1,5 +1,4 @@
 sphinx
 breathe
 sphinx-book-theme
-sphinx-copybutton
 mlx

@@ -10,7 +10,7 @@ import mlx.core as mx
 # -- Project information -----------------------------------------------------

 project = "MLX"
-copyright = "2023, Apple"
+copyright = "2023, MLX Contributors"
 author = "MLX Contributors"
 version = ".".join(mx.__version__.split(".")[:3])
 release = version

@@ -18,7 +18,6 @@ release = version
 # -- General configuration ---------------------------------------------------

 extensions = [
-    "sphinx_copybutton",
     "sphinx.ext.autodoc",
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",
@@ -8,12 +8,11 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
 Simple Example
 --------------

-.. currentmodule:: mlx.core
-
 Let's write a custom kernel that computes ``exp`` elementwise:

 .. code-block:: python

+    def exp_elementwise(a: mx.array):
         source = """
             uint elem = thread_position_in_grid.x;
             T tmp = inp[elem];
|
|||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@@ -42,13 +39,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
|||||||
b = exp_elementwise(a)
|
b = exp_elementwise(a)
|
||||||
assert mx.allclose(b, mx.exp(a))
|
assert mx.allclose(b, mx.exp(a))
|
||||||
|
|
||||||
Every time you make a kernel, a new Metal library is created and possibly
|
|
||||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
|
||||||
:func:`fast.metal_kernel` and then use it many times.
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Only pass the body of the Metal kernel in ``source``. The function
|
We are only required to pass the body of the Metal kernel in ``source``.
|
||||||
signature is generated automatically.
|
|
||||||
|
|
||||||
The full function signature will be generated using:
|
The full function signature will be generated using:
|
||||||
|
|
||||||
@@ -86,34 +78,29 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
|||||||
|
|
||||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||||
|
|
||||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||||
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||||
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||||
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
|
||||||
dimension should be less than or equal to the corresponding grid dimension.
|
|
||||||
|
|
||||||
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||||
generated code for debugging purposes.
|
|
||||||
|
|
||||||
Using Shape/Strides
|
Using Shape/Strides
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||||
is ``True`` by default. This will copy the array inputs if needed
|
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||||
before the kernel is launched to ensure that the memory layout is row
|
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||||
contiguous. Generally this makes writing the kernel easier, since we don't
|
when indexing.
|
||||||
have to worry about gaps or the ordering of the dims when indexing.
|
|
||||||
|
|
||||||
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||||
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
input array ``a`` if any are present in ``source``.
|
||||||
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||||
the right elements for each thread.
|
|
||||||
|
|
||||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||||
relying on a copy from ``ensure_row_contiguous``:
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||||
@@ -127,11 +114,8 @@ relying on a copy from ``ensure_row_contiguous``:
|
|||||||
name="myexp_strided",
|
name="myexp_strided",
|
||||||
input_names=["inp"],
|
input_names=["inp"],
|
||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source,
|
source=source
|
||||||
ensure_row_contiguous=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@@ -139,6 +123,7 @@ relying on a copy from ``ensure_row_contiguous``:
|
|||||||
threadgroup=(256, 1, 1),
|
threadgroup=(256, 1, 1),
|
||||||
output_shapes=[a.shape],
|
output_shapes=[a.shape],
|
||||||
output_dtypes=[a.dtype],
|
output_dtypes=[a.dtype],
|
||||||
|
ensure_row_contiguous=False,
|
||||||
)
|
)
|
||||||
return outputs[0]
|
return outputs[0]
|
||||||
|
|
||||||
@@ -198,13 +183,25 @@ We'll start with the following MLX implementation using standard ops:

        return output

- Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
+ Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

+   @mx.custom_function
+   def grid_sample(x, grid):
+
+       assert x.ndim == 4, "`x` must be 4D."
+       assert grid.ndim == 4, "`grid` must be 4D."
+
+       B, _, _, C = x.shape
+       _, gN, gM, D = grid.shape
+       out_shape = (B, gN, gM, C)
+
+       assert D == 2, "Last dim of `grid` must be size 2."
+
    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
@@ -254,26 +251,12 @@ First we'll implement the forward pass as a fused kernel:

        out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
    """

    kernel = mx.fast.metal_kernel(
        name="grid_sample",
        input_names=["x", "grid"],
        output_names=["out"],
        source=source,
    )

-   @mx.custom_function
-   def grid_sample(x, grid):
-
-       assert x.ndim == 4, "`x` must be 4D."
-       assert grid.ndim == 4, "`grid` must be 4D."
-
-       B, _, _, C = x.shape
-       _, gN, gM, D = grid.shape
-       out_shape = (B, gN, gM, C)
-
-       assert D == 2, "Last dim of `grid` must be size 2."
-
    outputs = kernel(
        inputs=[x, grid],
        template=[("T", x.dtype)],
@@ -298,11 +281,11 @@ On an M1 Max, we see a big performance improvement:

Grid Sample VJP
---------------

- Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
+ Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now
define its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
- requires a few extra :func:`fast.metal_kernel` features:
+ requires a few extra ``mx.fast.metal_kernel`` features:

* ``init_value=0``
  Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
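Schematically, and only as a sketch of how these options are wired up (the
launch geometry below is illustrative, not the real backward configuration):

.. code-block:: python

    # Sketch: output buffers are declared atomic in the generated kernel and
    # zero-initialized before launch so the kernel can update only some elements.
    kernel = mx.fast.metal_kernel(
        name="grid_sample_grad",
        input_names=["x", "grid", "cotangent"],
        output_names=["x_grad", "grid_grad"],
        source=source,
        atomic_outputs=True,
    )
    outputs = kernel(
        inputs=[x, grid, cotangent],
        template=[("T", x.dtype)],
        output_shapes=[x.shape, grid.shape],
        output_dtypes=[x.dtype, grid.dtype],
        grid=(x.size, 1, 1),        # illustrative launch geometry
        threadgroup=(256, 1, 1),
        init_value=0,               # zero the outputs before the kernel runs
    )
    x_grad, grid_grad = outputs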
@@ -316,6 +299,14 @@ We can then implement the backwards pass as follows:

.. code-block:: python

+   @grid_sample.vjp
+   def grid_sample_vjp(primals, cotangent, _):
+       x, grid = primals
+       B, _, _, C = x.shape
+       _, gN, gM, D = grid.shape
+
+       assert D == 2, "Last dim of `grid` must be size 2."
+
    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
@@ -415,15 +406,6 @@ We can then implement the backwards pass as follows:
        source=source,
        atomic_outputs=True,
    )

-   @grid_sample.vjp
-   def grid_sample_vjp(primals, cotangent, _):
-       x, grid = primals
-       B, _, _, C = x.shape
-       _, gN, gM, D = grid.shape
-
-       assert D == 2, "Last dim of `grid` must be size 2."
-
    # pad the output channels to simd group size
    # so that our `simd_sum`s don't overlap.
    simdgroup_size = 32
@@ -138,13 +138,13 @@ more concrete:

   * representing the vectorized computation and the axis which
   * corresponds to the output vectorized dimension.
   */
-  std::pair<std::vector<array>, std::vector<int>> vmap(
+  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

-  /** The name of primitive. */
-  const char* name() const override {
-    return "Axpby";
+  /** Print the primitive. */
+  void print(std::ostream& os) override {
+    os << "Axpby";
  }

  /** Equivalence check **/
@@ -394,14 +394,14 @@ below.
|
|||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
// Resolve name of kernel
|
// Resolve name of kernel
|
||||||
std::stream kname;
|
std::ostringstream kname;
|
||||||
kname = "axpby_general_" + type_to_name(out);
|
kname << "axpby_" << "general_" << type_to_name(out);
|
||||||
|
|
||||||
// Load the metal library
|
// Make sure the metal library is available
|
||||||
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
d.register_library("mlx_ext");
|
||||||
|
|
||||||
// Make a kernel from this metal library
|
// Make a kernel from this metal library
|
||||||
auto kernel = d.get_kernel(kname, lib);
|
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
|||||||
@@ -70,7 +70,6 @@ are the CPU and GPU.

   python/fft
   python/linalg
   python/metal
-  python/cuda
   python/memory_management
   python/nn
   python/optimizers
@@ -13,48 +13,22 @@ silicon computer is

    pip install mlx

- To install from PyPI your system must meet the following requirements:
+ To install from PyPI you must meet the following requirements:

  - Using an M series chip (Apple silicon)
- - Using a native Python >= 3.10
- - macOS >= 14.0
+ - Using a native Python >= 3.9
+ - macOS >= 13.5

  .. note::
-    MLX is only available on devices running macOS >= 14.0 and higher.
+    MLX is only available on devices running macOS >= 13.5
+    It is highly recommended to use macOS 14 (Sonoma)

- CUDA
- ^^^^
-
- MLX has a CUDA backend which you can install with:
+ MLX is also available on conda-forge. To install MLX with conda do:

  .. code-block:: shell

-    pip install mlx[cuda]
+    conda install conda-forge::mlx

- To install the CUDA package from PyPi your system must meet the following
- requirements:
-
- - Nvidia architecture >= SM 7.0 (Volta)
- - Nvidia driver >= 550.54.14
- - CUDA toolkit >= 12.0
- - Linux distribution with glibc >= 2.35
- - Python >= 3.10
-
- CPU-only (Linux)
- ^^^^^^^^^^^^^^^^
-
- For a CPU-only version of MLX that runs on Linux use:
-
- .. code-block:: shell
-
-    pip install mlx[cpu]
-
- To install the CPU-only package from PyPi your system must meet the following
- requirements:
-
- - Linux distribution with glibc >= 2.35
- - Python >= 3.10

Troubleshooting
@@ -91,8 +65,6 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
.. _python install:
|
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@@ -104,20 +76,20 @@ Then simply build and install MLX using pip:
|
|||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install .
|
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
||||||
|
|
||||||
For developing, install the package with development dependencies, and use an
|
For developing, install the package with development dependencies, and use an
|
||||||
editable install:
|
editable install:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install -e ".[dev]"
|
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
||||||
|
|
||||||
Once the development dependencies are installed, you can build faster with:
|
Once the development dependencies are installed, you can build faster with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
python setup.py build_ext --inplace
|
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
||||||
|
|
||||||
Run the tests with:
|
Run the tests with:
|
||||||
|
|
||||||
@@ -135,8 +107,6 @@ IDE:
|
|||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
.. _cpp install:
|
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@@ -215,7 +185,6 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
Binary Size Minimization
|
Binary Size Minimization
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -244,50 +213,6 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
|||||||
application. Once a kernel is compiled, it will be cached by the system. The
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
Metal kernel cache persists across reboots.
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
Linux
|
|
||||||
^^^^^
|
|
||||||
|
|
||||||
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
|
||||||
For example on Ubuntu, run the following:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
apt-get update -y
|
|
||||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
|
||||||
|
|
||||||
From here follow the instructions to install either the :ref:`Python <python
|
|
||||||
install>` or :ref:`C++ <cpp install>` APIs.
|
|
||||||
|
|
||||||
CUDA
|
|
||||||
^^^^
|
|
||||||
|
|
||||||
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
|
||||||
and the CUDA toolkit. For example on Ubuntu, run the following:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
|
||||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
|
||||||
apt-get update -y
|
|
||||||
apt-get -y install cuda-toolkit-12-9
|
|
||||||
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
|
||||||
|
|
||||||
|
|
||||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
|
||||||
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
|
||||||
|
|
||||||
To build the C++ package run:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
mkdir -p build && cd build
|
|
||||||
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ Array
|
|||||||
array.ndim
|
array.ndim
|
||||||
array.shape
|
array.shape
|
||||||
array.size
|
array.size
|
||||||
array.real
|
|
||||||
array.imag
|
|
||||||
array.abs
|
array.abs
|
||||||
array.all
|
array.all
|
||||||
array.any
|
array.any
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
CUDA
|
|
||||||
=====
|
|
||||||
|
|
||||||
.. currentmodule:: mlx.core.cuda
|
|
||||||
|
|
||||||
.. autosummary::
|
|
||||||
:toctree: _autosummary
|
|
||||||
|
|
||||||
is_available
|
|
||||||
@@ -13,4 +13,3 @@ Fast
|
|||||||
rope
|
rope
|
||||||
scaled_dot_product_attention
|
scaled_dot_product_attention
|
||||||
metal_kernel
|
metal_kernel
|
||||||
cuda_kernel
|
|
||||||
|
|||||||
@@ -20,5 +20,3 @@ FFT
|
|||||||
irfft2
|
irfft2
|
||||||
rfftn
|
rfftn
|
||||||
irfftn
|
irfftn
|
||||||
fftshift
|
|
||||||
ifftshift
|
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ Linear Algebra
|
|||||||
cross
|
cross
|
||||||
qr
|
qr
|
||||||
svd
|
svd
|
||||||
eigvals
|
|
||||||
eig
|
|
||||||
eigvalsh
|
eigvalsh
|
||||||
eigh
|
eigh
|
||||||
lu
|
lu
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ simple functions.
|
|||||||
mish
|
mish
|
||||||
prelu
|
prelu
|
||||||
relu
|
relu
|
||||||
relu2
|
|
||||||
relu6
|
relu6
|
||||||
selu
|
selu
|
||||||
sigmoid
|
sigmoid
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ Layers
|
|||||||
QuantizedLinear
|
QuantizedLinear
|
||||||
RMSNorm
|
RMSNorm
|
||||||
ReLU
|
ReLU
|
||||||
ReLU2
|
|
||||||
ReLU6
|
ReLU6
|
||||||
RNN
|
RNN
|
||||||
RoPE
|
RoPE
|
||||||
|
|||||||
@@ -112,7 +112,6 @@ Operations
|
|||||||
max
|
max
|
||||||
maximum
|
maximum
|
||||||
mean
|
mean
|
||||||
median
|
|
||||||
meshgrid
|
meshgrid
|
||||||
min
|
min
|
||||||
minimum
|
minimum
|
||||||
|
|||||||
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:

    optimizer.update(model, grads)

    # Save the state
-   state = tree_flatten(optimizer.state, destination={})
-   mx.save_safetensors("optimizer.safetensors", state)
+   state = tree_flatten(optimizer.state)
+   mx.save_safetensors("optimizer.safetensors", dict(state))

    # Later on, for example when loading from a checkpoint,
    # recreate the optimizer and load the state
    optimizer = optim.Adam(learning_rate=1e-2)

-   state = tree_unflatten(mx.load("optimizer.safetensors"))
+   state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
    optimizer.state = state

Note, not every optimizer configuration parameter is saved in the state. For
|
|||||||
@@ -19,4 +19,3 @@ Common Optimizers
|
|||||||
Adamax
|
Adamax
|
||||||
Lion
|
Lion
|
||||||
MultiOptimizer
|
MultiOptimizer
|
||||||
Muon
|
|
||||||
|
|||||||
@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:

.. code-block:: python

   x = mx.random.uniform(shape=(32, 1000, 4096))
-  timeit(gelu, x)
-  timeit(mx.compile(gelu), x)
+  timeit(nn.gelu, x)
+  timeit(mx.compile(nn.gelu), x)

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
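The ``timeit`` helper is not shown in this excerpt; a minimal sketch of one,
assuming only that MLX evaluation is lazy, is:

.. code-block:: python

    import time

    def timeit(fn, x, iters=100):
        # Warm up and force evaluation since MLX is lazy.
        mx.eval(fn(x))
        tic = time.perf_counter()
        for _ in range(iters):
            out = fn(x)
        mx.eval(out)
        toc = time.perf_counter()
        print(f"{1e3 * (toc - tic) / iters:.3f} ms per iteration")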
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,

    def fun(x, y):
        z = x + y
        state.append(z)
-       return mx.exp(z)
+       return mx.exp(z), state

    fun(mx.array(1.0), mx.array(2.0))
    # Prints [array(3, dtype=float32)]
@@ -7,13 +7,12 @@ Distributed Communication

MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
- moment we support three different communication backends:
+ moment we support two different communication backends:

* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
  full-featured and mature distributed communications library
- * A **ring** backend of our own that uses native TCP sockets. It should be
-   faster for thunderbolt connections, but it also works over Ethernet.
- * `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
+ * A **ring** backend of our own that uses native TCP sockets and should be
+   faster for thunderbolt connections.

The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
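For example, a minimal sketch of the basic usage pattern is:

.. code-block:: python

    import mlx.core as mx

    world = mx.distributed.init()          # pick an available backend
    x = mx.array([world.rank()], dtype=mx.float32)
    total = mx.distributed.all_sum(x)      # sum the value across all processes
    print(world.rank(), total)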
@@ -85,8 +84,9 @@ Selecting Backend
^^^^^^^^^^^^^^^^^

You can select the backend you want to use when calling :func:`init` by passing
- one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
- available backends. If they all fail then a singleton group is created.
+ one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
+ initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
+ both fail then a singleton group is created.
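For example, a small sketch of explicit backend selection:

.. code-block:: python

    import mlx.core as mx

    # Ask specifically for the ring backend.
    group = mx.distributed.init(backend="ring")

    # Or let MLX fall back through the available backends.
    group = mx.distributed.init(backend="any")
    print(group.rank(), group.size())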
.. note::
   After a distributed backend is successfully initialized :func:`init` will

@@ -184,7 +184,7 @@ almost identical to the example above:

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
-       grads = mx.nn.average_gradients(grads) # <---- This line was added
+       grads = mlx.nn.average_gradients(grads) # <---- This line was added
        optimizer.update(model, grads)
        return loss

@@ -220,7 +220,7 @@ print 4 etc.

Installing MPI
^^^^^^^^^^^^^^

- MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
+ MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

@@ -228,16 +228,14 @@ with the Anaconda package manager as follows:

   $ conda install conda-forge::openmpi

- Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
+ Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
- done automatically by ``mlx.launch``. Some environments use a non-standard
- library filename that can be specified using the ``MPI_LIBNAME`` environment
- variable. This is automatically taken care of by ``mlx.launch`` as well.
+ done automatically by ``mlx.launch``.

.. code:: shell

-   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
+   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
    $ # or simply
    $ mlx.launch -n 2 test.py
@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:

        model.update(tree_unflatten(list(params.items())))
        return model(x)

-   params = tree_flatten(model.parameters(), destination={})
+   params = dict(tree_flatten(model.parameters()))
    mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:

.. code-block:: python

-   mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
+   mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
    imported_abs = mx.import_function("fun.mlxfn")

    # Ok
-   out, = imported_abs(mx.array([-1.0]))
+   out, = imported_abs(mx.array(-1.0))

    # Also ok
    out, = imported_abs(mx.array([-1.0, -2.0]))
@@ -70,8 +70,7 @@ Differences from NumPy

* Indexing does not perform bounds checking. Indexing out of bounds is
  undefined behavior.
- * Boolean mask based indexing is supported for assignment only (see
-   :ref:`boolean-mask-assignment`).
+ * Boolean mask based indexing is not yet supported.

The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the

@@ -108,28 +107,6 @@ same array:

    >>> a
    array([1, 2, 0], dtype=int32)

- Note that unlike NumPy, slicing an array creates a copy, not a view. So
- mutating it does not mutate the original array:
-
- .. code-block:: shell
-
-   >>> a = mx.array([1, 2, 3])
-   >>> b = a[:]
-   >>> b[2] = 0
-   >>> b
-   array([1, 2, 0], dtype=int32)
-   >>> a
-   array([1, 2, 3], dtype=int32)
-
- Also unlike NumPy, updates to the same location are nondeterministic:
-
- .. code-block:: shell
-
-   >>> a = mx.array([1, 2, 3])
-   >>> a[[0, 0]] = mx.array([4, 5])
-
- The first element of ``a`` could be ``4`` or ``5``.

Transformations of functions which use in-place updates are allowed and work as
expected. For example:

@@ -144,51 +121,3 @@ expected. For example:
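A minimal sketch of such a function and its gradient:

.. code-block:: python

    def fun(x, idx):
        x[idx] = 2.0
        return x.sum()

    dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
    print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)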
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.

- .. _boolean-mask-assignment:
-
- Boolean Mask Assignment
- -----------------------
-
- MLX supports boolean indices using NumPy syntax. A mask must already be
- a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
- Other index types are routed through the standard scatter code.
-
- .. code-block:: shell
-
-   >>> a = mx.array([1.0, 2.0, 3.0])
-   >>> mask = mx.array([True, False, True])
-   >>> updates = mx.array([5.0, 6.0])
-   >>> a[mask] = updates
-   >>> a
-   array([5.0, 2.0, 6.0], dtype=float32)
-
- Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
- assignments, ``updates`` must provide at least as many elements as there are
- ``True`` entries in ``mask``.
-
- .. code-block:: shell
-
-   >>> a = mx.zeros((2, 3))
-   >>> mask = mx.array([[True, False, True],
-                        [False, False, True]])
-   >>> a[mask] = 1.0
-   >>> a
-   array([[1.0, 0.0, 1.0],
-          [0.0, 0.0, 1.0]], dtype=float32)
-
- Boolean masks follow NumPy semantics:
-
- - The mask shape must match the shape of the axes it indexes exactly. The only
-   exception is a scalar boolean mask, which broadcasts to the full array.
- - Any axes not covered by the mask are taken in full.
-
- .. code-block:: shell
-
-   >>> a = mx.arange(1000).reshape(10, 10, 10)
-   >>> a[mx.random.randn(10, 10) > 0.0] = 0  # valid: mask covers axes 0 and 1
-
- The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
- selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
- Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
- axes and therefore raise errors.
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
// Copyright © 2023-2025 Apple Inc.
|
// Copyright © 2023-2025 Apple Inc.
|
||||||
|
|
||||||
#include <dlfcn.h>
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
@@ -17,19 +16,6 @@
|
|||||||
|
|
||||||
namespace my_ext {
|
namespace my_ext {
|
||||||
|
|
||||||
// A helper function to find the location of the current binary on disk.
|
|
||||||
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
|
|
||||||
std::string current_binary_dir() {
|
|
||||||
static std::string binary_dir = []() {
|
|
||||||
Dl_info info;
|
|
||||||
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
|
||||||
throw std::runtime_error("Unable to get current binary dir.");
|
|
||||||
}
|
|
||||||
return std::filesystem::path(info.dli_fname).parent_path().string();
|
|
||||||
}();
|
|
||||||
return binary_dir;
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Operation Implementation
|
// Operation Implementation
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -181,15 +167,16 @@ void Axpby::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolve name of kernel (corresponds to axpby.metal)
|
// Resolve name of kernel (corresponds to axpby.metal)
|
||||||
std::string kname = "axpby_";
|
std::ostringstream kname;
|
||||||
kname += (contiguous_kernel ? "contiguous_" : "general_");
|
kname << "axpby_";
|
||||||
kname += type_to_name(out);
|
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||||
|
kname << type_to_name(out);
|
||||||
|
|
||||||
// Load the metal library
|
// Make sure the metal library is available
|
||||||
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
d.register_library("mlx_ext");
|
||||||
|
|
||||||
// Make a kernel from this metal library
|
// Make a kernel from this metal library
|
||||||
auto kernel = d.get_kernel(kname, lib);
|
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
|||||||
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
|
|||||||
const std::vector<mx::array>& inputs,
|
const std::vector<mx::array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
/** The name of primitive. */
|
/** Print the primitive. */
|
||||||
const char* name() const override {
|
void print(std::ostream& os) override {
|
||||||
return "Axpby";
|
os << "Axpby";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Equivalence check **/
|
/** Equivalence check **/
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
setuptools>=42
|
setuptools>=42
|
||||||
cmake>=3.25
|
cmake>=3.25
|
||||||
mlx>=0.21.0
|
mlx>=0.21.0
|
||||||
nanobind==2.4.0
|
nanobind==2.2.0
|
||||||
|
|||||||
@@ -3,10 +3,8 @@ from mlx_sample_extensions import axpby
|
|||||||
|
|
||||||
a = mx.ones((3, 4))
|
a = mx.ones((3, 4))
|
||||||
b = mx.ones((3, 4))
|
b = mx.ones((3, 4))
|
||||||
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||||
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
|
|
||||||
|
|
||||||
print(f"c shape: {c_cpu.shape}")
|
print(f"c shape: {c.shape}")
|
||||||
print(f"c dtype: {c_cpu.dtype}")
|
print(f"c dtype: {c.dtype}")
|
||||||
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
|
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||||
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||||
@@ -21,7 +20,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||||
|
|
||||||
# Define MLX_VERSION only in the version.cpp file.
|
# Define MLX_VERSION only in the version.cpp file.
|
||||||
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||||
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
||||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
||||||
|
|
||||||
@@ -49,19 +48,5 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
|||||||
if(MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||||
else()
|
else()
|
||||||
target_sources(mlx
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(MLX_BUILD_CUDA)
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
|
||||||
else()
|
|
||||||
target_sources(mlx
|
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
|
||||||
else()
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class Buffer {
|
|||||||
void* ptr_;
|
void* ptr_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit Buffer(void* ptr) : ptr_(ptr) {};
|
Buffer(void* ptr) : ptr_(ptr) {};
|
||||||
|
|
||||||
// Get the raw data pointer from the buffer
|
// Get the raw data pointer from the buffer
|
||||||
void* raw_ptr();
|
void* raw_ptr();
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
|
|||||||
other.strides(),
|
other.strides(),
|
||||||
other.flags(),
|
other.flags(),
|
||||||
[](auto) {});
|
[](auto) {});
|
||||||
cpy.array_desc_->offset = other.array_desc_->offset;
|
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||||
return cpy;
|
return cpy;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
|
|||||||
|
|
||||||
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||||
array_desc_->offset = 0;
|
array_desc_->data_ptr = buffer.raw_ptr();
|
||||||
array_desc_->data_size = size();
|
array_desc_->data_size = size();
|
||||||
array_desc_->flags.contiguous = true;
|
array_desc_->flags.contiguous = true;
|
||||||
array_desc_->flags.row_contiguous = true;
|
array_desc_->flags.row_contiguous = true;
|
||||||
@@ -156,7 +156,7 @@ void array::set_data(
|
|||||||
Flags flags,
|
Flags flags,
|
||||||
Deleter d) {
|
Deleter d) {
|
||||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||||
array_desc_->offset = 0;
|
array_desc_->data_ptr = buffer.raw_ptr();
|
||||||
array_desc_->data_size = data_size;
|
array_desc_->data_size = data_size;
|
||||||
array_desc_->strides = std::move(strides);
|
array_desc_->strides = std::move(strides);
|
||||||
array_desc_->flags = flags;
|
array_desc_->flags = flags;
|
||||||
@@ -167,13 +167,14 @@ void array::copy_shared_buffer(
|
|||||||
const Strides& strides,
|
const Strides& strides,
|
||||||
Flags flags,
|
Flags flags,
|
||||||
size_t data_size,
|
size_t data_size,
|
||||||
int64_t offset /* = 0 */) {
|
size_t offset /* = 0 */) {
|
||||||
array_desc_->data = other.array_desc_->data;
|
array_desc_->data = other.array_desc_->data;
|
||||||
array_desc_->strides = strides;
|
array_desc_->strides = strides;
|
||||||
array_desc_->flags = flags;
|
array_desc_->flags = flags;
|
||||||
array_desc_->data_size = data_size;
|
array_desc_->data_size = data_size;
|
||||||
array_desc_->offset =
|
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||||
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
|
array_desc_->data_ptr = static_cast<void*>(
|
||||||
|
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::copy_shared_buffer(const array& other) {
|
void array::copy_shared_buffer(const array& other) {
|
||||||
@@ -240,8 +241,8 @@ array::ArrayDesc::ArrayDesc(
|
|||||||
std::vector<array> inputs)
|
std::vector<array> inputs)
|
||||||
: shape(std::move(shape)),
|
: shape(std::move(shape)),
|
||||||
dtype(dtype),
|
dtype(dtype),
|
||||||
primitive(std::move(primitive)),
|
|
||||||
status(Status::unscheduled),
|
status(Status::unscheduled),
|
||||||
|
primitive(std::move(primitive)),
|
||||||
inputs(std::move(inputs)) {
|
inputs(std::move(inputs)) {
|
||||||
init();
|
init();
|
||||||
}
|
}
|
||||||
|
|||||||
42
mlx/array.h
42
mlx/array.h
@@ -10,7 +10,6 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/dtype.h"
|
#include "mlx/dtype.h"
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
#include "mlx/small_vector.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -19,8 +18,8 @@ class Primitive;
|
|||||||
|
|
||||||
using Deleter = std::function<void(allocator::Buffer)>;
|
using Deleter = std::function<void(allocator::Buffer)>;
|
||||||
using ShapeElem = int32_t;
|
using ShapeElem = int32_t;
|
||||||
using Shape = SmallVector<ShapeElem>;
|
using Shape = std::vector<ShapeElem>;
|
||||||
using Strides = SmallVector<int64_t>;
|
using Strides = std::vector<int64_t>;
|
||||||
|
|
||||||
class array {
|
class array {
|
||||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||||
@@ -225,10 +224,6 @@ class array {
|
|||||||
// Not copyable
|
// Not copyable
|
||||||
Data(const Data& d) = delete;
|
Data(const Data& d) = delete;
|
||||||
Data& operator=(const Data& d) = delete;
|
Data& operator=(const Data& d) = delete;
|
||||||
Data(Data&& o) : buffer(o.buffer), d(o.d) {
|
|
||||||
o.buffer = allocator::Buffer(nullptr);
|
|
||||||
o.d = [](allocator::Buffer) {};
|
|
||||||
}
|
|
||||||
~Data() {
|
~Data() {
|
||||||
d(buffer);
|
d(buffer);
|
||||||
}
|
}
|
||||||
@@ -294,11 +289,6 @@ class array {
|
|||||||
return array_desc_->siblings;
|
return array_desc_->siblings;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** The array's position in the sibling list. */
|
|
||||||
int sibling_position() const {
|
|
||||||
return array_desc_->position;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||||
array_desc_->siblings = std::move(siblings);
|
array_desc_->siblings = std::move(siblings);
|
||||||
array_desc_->position = position;
|
array_desc_->position = position;
|
||||||
@@ -349,32 +339,24 @@ class array {
|
|||||||
return allocator::allocator().size(buffer());
|
return allocator::allocator().size(buffer());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the shared pointer to the array::Data struct
|
// Return a copy of the shared pointer
|
||||||
const std::shared_ptr<Data>& data_shared_ptr() const {
|
// to the array::Data struct
|
||||||
|
std::shared_ptr<Data> data_shared_ptr() const {
|
||||||
return array_desc_->data;
|
return array_desc_->data;
|
||||||
}
|
}
|
||||||
|
// Return a raw pointer to the arrays data
|
||||||
// Return a raw pointer to the arrays data. This function may do a copy if
|
|
||||||
// the underlying buffer is not accessible on the CPU. When accessing the
|
|
||||||
// data for GPU kernels, be sure to use the correct method / function for the
|
|
||||||
// given backend to access the GPU pointer.
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* data() {
|
T* data() {
|
||||||
return reinterpret_cast<T*>(
|
return static_cast<T*>(array_desc_->data_ptr);
|
||||||
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
const T* data() const {
|
const T* data() const {
|
||||||
return const_cast<array&>(*this).data<T>();
|
return static_cast<T*>(array_desc_->data_ptr);
|
||||||
}
|
|
||||||
|
|
||||||
int64_t offset() const {
|
|
||||||
return array_desc_->offset;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Status {
|
enum Status {
|
||||||
// The output of a computation which has not been scheduled.
|
// The ouptut of a computation which has not been scheduled.
|
||||||
// For example, the status of `x` in `auto x = a + b`.
|
// For example, the status of `x` in `auto x = a + b`.
|
||||||
unscheduled,
|
unscheduled,
|
||||||
|
|
||||||
@@ -439,7 +421,7 @@ class array {
|
|||||||
const Strides& strides,
|
const Strides& strides,
|
||||||
Flags flags,
|
Flags flags,
|
||||||
size_t data_size,
|
size_t data_size,
|
||||||
int64_t offset = 0);
|
size_t offset = 0);
|
||||||
|
|
||||||
void copy_shared_buffer(const array& other);
|
void copy_shared_buffer(const array& other);
|
||||||
|
|
||||||
@@ -474,8 +456,8 @@ class array {
|
|||||||
// can share the underlying data buffer.
|
// can share the underlying data buffer.
|
||||||
std::shared_ptr<Data> data;
|
std::shared_ptr<Data> data;
|
||||||
|
|
||||||
// Offset from beginning of data pointer
|
// Properly offset data pointer
|
||||||
int64_t offset{0};
|
void* data_ptr{nullptr};
|
||||||
|
|
||||||
// The size in elements of the data buffer the array accesses
|
// The size in elements of the data buffer the array accesses
|
||||||
size_t data_size;
|
size_t data_size;
|
||||||
|
|||||||
@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
|
|||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
array& out,
|
array& out,
|
||||||
BinaryOpType bopt,
|
BinaryOpType bopt) {
|
||||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
|
||||||
bool b_donatable = is_donatable(b, out);
|
bool b_donatable = is_donatable(b, out);
|
||||||
bool a_donatable = is_donatable(a, out);
|
bool a_donatable = is_donatable(a, out);
|
||||||
switch (bopt) {
|
switch (bopt) {
|
||||||
case BinaryOpType::ScalarScalar:
|
case BinaryOpType::ScalarScalar:
|
||||||
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
|
out.set_data(
|
||||||
|
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::ScalarVector:
|
case BinaryOpType::ScalarVector:
|
||||||
if (b_donatable) {
|
if (b_donatable) {
|
||||||
out.copy_shared_buffer(b);
|
out.copy_shared_buffer(b);
|
||||||
} else {
|
} else {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
mallocfn(b.data_size() * out.itemsize()),
|
allocator::malloc(b.data_size() * out.itemsize()),
|
||||||
b.data_size(),
|
b.data_size(),
|
||||||
b.strides(),
|
b.strides(),
|
||||||
b.flags());
|
b.flags());
|
||||||
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
|
|||||||
out.copy_shared_buffer(a);
|
out.copy_shared_buffer(a);
|
||||||
} else {
|
} else {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
mallocfn(a.data_size() * out.itemsize()),
|
allocator::malloc(a.data_size() * out.itemsize()),
|
||||||
a.data_size(),
|
a.data_size(),
|
||||||
a.strides(),
|
a.strides(),
|
||||||
a.flags());
|
a.flags());
|
||||||
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
|
|||||||
out.copy_shared_buffer(b);
|
out.copy_shared_buffer(b);
|
||||||
} else {
|
} else {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
mallocfn(a.data_size() * out.itemsize()),
|
allocator::malloc(a.data_size() * out.itemsize()),
|
||||||
a.data_size(),
|
a.data_size(),
|
||||||
a.strides(),
|
a.strides(),
|
||||||
a.flags());
|
a.flags());
|
||||||
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
|
|||||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||||
out.copy_shared_buffer(b);
|
out.copy_shared_buffer(b);
|
||||||
} else {
|
} else {
|
||||||
out.set_data(mallocfn(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ namespace mlx::core {
|
|||||||
|
|
||||||
void broadcast(const array& in, array& out) {
|
void broadcast(const array& in, array& out) {
|
||||||
if (out.size() == 0) {
|
if (out.size() == 0) {
|
||||||
out.set_data(allocator::malloc(0));
|
out.set_data(nullptr);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Strides strides(out.ndim(), 0);
|
Strides strides(out.ndim(), 0);
|
||||||
|
|||||||
@@ -1,157 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <functional>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class BufferCache {
|
|
||||||
public:
|
|
||||||
BufferCache(
|
|
||||||
size_t page_size,
|
|
||||||
std::function<size_t(T*)> get_size,
|
|
||||||
std::function<void(T*)> free)
|
|
||||||
: page_size_(page_size),
|
|
||||||
get_size_(std::move(get_size)),
|
|
||||||
free_(std::move(free)) {}
|
|
||||||
|
|
||||||
~BufferCache() {
|
|
||||||
clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
BufferCache(const BufferCache&) = delete;
|
|
||||||
BufferCache& operator=(const BufferCache&) = delete;
|
|
||||||
|
|
||||||
T* reuse_from_cache(size_t size) {
|
|
||||||
// Find the closest buffer in pool.
|
|
||||||
auto it = buffer_pool_.lower_bound(size);
|
|
||||||
if (it == buffer_pool_.end() ||
|
|
||||||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect from the cache.
|
|
||||||
T* buf = it->second->buf;
|
|
||||||
pool_size_ -= it->first;
|
|
||||||
|
|
||||||
// Remove from record.
|
|
||||||
remove_from_list(it->second);
|
|
||||||
buffer_pool_.erase(it);
|
|
||||||
return buf;
|
|
||||||
}
|
|
||||||
|
|
||||||
void recycle_to_cache(T* buf) {
|
|
||||||
assert(buf);
|
|
||||||
// Add to cache.
|
|
||||||
BufferHolder* bh = new BufferHolder(buf);
|
|
||||||
add_at_head(bh);
|
|
||||||
size_t size = get_size_(buf);
|
|
||||||
pool_size_ += size;
|
|
||||||
buffer_pool_.emplace(size, bh);
|
|
||||||
}
|
|
||||||
|
|
||||||
int release_cached_buffers(size_t min_bytes_to_free) {
|
|
||||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
|
||||||
return clear();
|
|
||||||
} else {
|
|
||||||
int n_release = 0;
|
|
||||||
size_t total_bytes_freed = 0;
|
|
||||||
|
|
||||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
|
||||||
// Release buffer.
|
|
||||||
size_t size = get_size_(tail_->buf);
|
|
||||||
total_bytes_freed += size;
|
|
||||||
free_(tail_->buf);
|
|
||||||
n_release++;
|
|
||||||
|
|
||||||
// Remove from record.
|
|
||||||
auto its = buffer_pool_.equal_range(size);
|
|
||||||
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
|
|
||||||
return el.second == tail_;
|
|
||||||
});
|
|
||||||
assert(it != buffer_pool_.end());
|
|
||||||
buffer_pool_.erase(it);
|
|
||||||
remove_from_list(tail_);
|
|
||||||
}
|
|
||||||
|
|
||||||
pool_size_ -= total_bytes_freed;
|
|
||||||
return n_release;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int clear() {
|
|
||||||
int n_release = 0;
|
|
||||||
for (auto& [size, holder] : buffer_pool_) {
|
|
||||||
free_(holder->buf);
|
|
||||||
n_release++;
|
|
||||||
delete holder;
|
|
||||||
}
|
|
||||||
buffer_pool_.clear();
|
|
||||||
pool_size_ = 0;
|
|
||||||
head_ = nullptr;
|
|
||||||
tail_ = nullptr;
|
|
||||||
return n_release;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t cache_size() const {
|
|
||||||
return pool_size_;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t page_size() const {
|
|
||||||
return page_size_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
struct BufferHolder {
|
|
||||||
public:
|
|
||||||
explicit BufferHolder(T* buf_) : buf(buf_) {}
|
|
||||||
|
|
||||||
BufferHolder* prev{nullptr};
|
|
||||||
BufferHolder* next{nullptr};
|
|
||||||
T* buf;
|
|
||||||
};
|
|
||||||
|
|
||||||
void add_at_head(BufferHolder* to_add) {
|
|
||||||
if (!head_) {
|
|
||||||
head_ = to_add;
|
|
||||||
tail_ = to_add;
|
|
||||||
} else {
|
|
||||||
head_->prev = to_add;
|
|
||||||
to_add->next = head_;
|
|
||||||
head_ = to_add;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void remove_from_list(BufferHolder* to_remove) {
|
|
||||||
if (to_remove->prev && to_remove->next) { // if middle
|
|
||||||
to_remove->prev->next = to_remove->next;
|
|
||||||
to_remove->next->prev = to_remove->prev;
|
|
||||||
} else if (to_remove->prev && to_remove == tail_) { // if tail
|
|
||||||
tail_ = to_remove->prev;
|
|
||||||
tail_->next = nullptr;
|
|
||||||
} else if (to_remove == head_ && to_remove->next) { // if head
|
|
||||||
head_ = to_remove->next;
|
|
||||||
head_->prev = nullptr;
|
|
||||||
} else if (to_remove == head_ && to_remove == tail_) { // if only element
|
|
||||||
head_ = nullptr;
|
|
||||||
tail_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
delete to_remove;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
|
||||||
BufferHolder* head_{nullptr};
|
|
||||||
BufferHolder* tail_{nullptr};
|
|
||||||
size_t pool_size_{0};
|
|
||||||
|
|
||||||
const size_t page_size_;
|
|
||||||
std::function<size_t(T*)> get_size_;
|
|
||||||
std::function<void(T*)> free_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/graph_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -14,8 +15,6 @@ void print_constant(std::ostream& os, const array& x) {
|
|||||||
return print_float_constant<float16_t>(os, x);
|
return print_float_constant<float16_t>(os, x);
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return print_float_constant<bfloat16_t>(os, x);
|
return print_float_constant<bfloat16_t>(os, x);
|
||||||
case float64:
|
|
||||||
return print_float_constant<double>(os, x);
|
|
||||||
case complex64:
|
case complex64:
|
||||||
return print_complex_constant<complex64_t>(os, x);
|
return print_complex_constant<complex64_t>(os, x);
|
||||||
case int8:
|
case int8:
|
||||||
@@ -52,8 +51,6 @@ std::string get_type_string(Dtype d) {
|
|||||||
return "float16_t";
|
return "float16_t";
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return "bfloat16_t";
|
return "bfloat16_t";
|
||||||
case float64:
|
|
||||||
return "double";
|
|
||||||
case complex64:
|
case complex64:
|
||||||
return "complex64_t";
|
return "complex64_t";
|
||||||
case bool_:
|
case bool_:
|
||||||
@@ -82,6 +79,55 @@ std::string get_type_string(Dtype d) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string build_lib_name(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
const std::vector<array>& tape,
|
||||||
|
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||||
|
NodeNamer namer;
|
||||||
|
std::ostringstream os;
|
||||||
|
std::ostringstream constant_hasher;
|
||||||
|
|
||||||
|
// Fill the input names. This is not really necessary, I just like having A,
|
||||||
|
// B, C, ... as the inputs.
|
||||||
|
for (auto& x : inputs) {
|
||||||
|
namer.get_name(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The primitives describing the tape. For unary and binary primitives this
|
||||||
|
// must be enough to describe the full computation.
|
||||||
|
for (auto& a : tape) {
|
||||||
|
// name and type of output
|
||||||
|
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||||
|
// computation performed
|
||||||
|
a.primitive().print(os);
|
||||||
|
// name of inputs to the function
|
||||||
|
for (auto& inp : a.inputs()) {
|
||||||
|
os << namer.get_name(inp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << "_";
|
||||||
|
|
||||||
|
for (auto& x : inputs) {
|
||||||
|
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||||
|
os << "C";
|
||||||
|
print_constant(constant_hasher, x);
|
||||||
|
} else {
|
||||||
|
os << (is_scalar(x) ? "S" : "V");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << "_";
|
||||||
|
for (auto& x : inputs) {
|
||||||
|
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os << kindof(x.dtype()) << x.itemsize();
|
||||||
|
}
|
||||||
|
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
bool compiled_check_contiguity(
|
bool compiled_check_contiguity(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const Shape& shape) {
|
const Shape& shape) {
|
||||||
@@ -113,10 +159,9 @@ bool compiled_check_contiguity(
|
|||||||
void compiled_allocate_outputs(
|
void compiled_allocate_outputs(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::function<bool(size_t)>& is_constant,
|
const std::vector<array>& inputs_,
|
||||||
bool contiguous,
|
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||||
const std::function<allocator::Buffer(size_t)>&
|
bool contiguous) {
|
||||||
mallocfn /* = allocator::malloc */) {
|
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
int o = 0;
|
int o = 0;
|
||||||
Strides strides;
|
Strides strides;
|
||||||
@@ -130,7 +175,8 @@ void compiled_allocate_outputs(
|
|||||||
// - Donatable
|
// - Donatable
|
||||||
// - Not a constant
|
// - Not a constant
|
||||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||||
in.is_donatable() && is_constant(i)) {
|
in.is_donatable() &&
|
||||||
|
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||||
outputs[o++].copy_shared_buffer(in);
|
outputs[o++].copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
// Get representative input flags to properly set non-donated outputs
|
// Get representative input flags to properly set non-donated outputs
|
||||||
@@ -142,7 +188,7 @@ void compiled_allocate_outputs(
|
|||||||
}
|
}
|
||||||
for (; o < outputs.size(); ++o) {
|
for (; o < outputs.size(); ++o) {
|
||||||
outputs[o].set_data(
|
outputs[o].set_data(
|
||||||
mallocfn(data_size * outputs[o].itemsize()),
|
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||||
data_size,
|
data_size,
|
||||||
strides,
|
strides,
|
||||||
flags);
|
flags);
|
||||||
@@ -158,86 +204,16 @@ void compiled_allocate_outputs(
       // - Not a constant
       if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
           in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-          is_constant(i)) {
+          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
         outputs[o].copy_shared_buffer(
             in, outputs[o].strides(), in.flags(), in.data_size());
         o++;
       }
     }
     for (; o < outputs.size(); ++o) {
-      outputs[o].set_data(mallocfn(outputs[o].nbytes()));
+      outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
     }
   }
 }
 
-std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
-    const std::vector<array>& inputs,
-    const array& out,
-    const std::function<bool(size_t)>& is_constant) {
-  const Shape& shape = out.shape();
-  bool contiguous = compiled_check_contiguity(inputs, shape);
-  if (contiguous) {
-    return {true, shape, {}};
-  }
-
-  std::vector<Strides> strides_vec{out.strides()};
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    // Skip constants.
-    if (is_constant(i)) {
-      continue;
-    }
-
-    // Skip scalar inputs.
-    const auto& x = inputs[i];
-    if (is_scalar(x)) {
-      continue;
-    }
-
-    // Broadcast the inputs to the output shape.
-    Strides xstrides;
-    size_t j = 0;
-    for (; j < shape.size() - x.ndim(); ++j) {
-      if (shape[j] == 1) {
-        xstrides.push_back(out.strides()[j]);
-      } else {
-        xstrides.push_back(0);
-      }
-    }
-    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
-      if (x.shape(i) == 1) {
-        if (shape[j] == 1) {
-          xstrides.push_back(out.strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      } else {
-        xstrides.push_back(x.strides()[i]);
-      }
-    }
-    strides_vec.push_back(std::move(xstrides));
-  }
-
-  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
-  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
-}
-
-bool compiled_use_large_index(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    bool contiguous) {
-  if (contiguous) {
-    size_t max_size = 0;
-    for (const auto& in : inputs) {
-      max_size = std::max(max_size, in.data_size());
-    }
-    return max_size > UINT32_MAX;
-  } else {
-    size_t max_size = 0;
-    for (const auto& o : outputs) {
-      max_size = std::max(max_size, o.size());
-    }
-    return max_size > UINT32_MAX;
-  }
-}
-
 } // namespace mlx::core
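
A note on the stride handling in compiled_collapse_contiguous_dims above: each non-constant, non-scalar input gets strides aligned to the output shape, with 0 for broadcast dimensions, before the shapes are collapsed. Below is a minimal standalone sketch of that alignment step; the type aliases and the broadcast_strides name are assumptions for illustration, not MLX API.

// Sketch only: align an input's strides with a larger output shape, assuming
// in_shape.size() <= out_shape.size().
#include <cstdint>
#include <cstdio>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

Strides broadcast_strides(
    const Shape& out_shape,
    const Strides& out_strides,
    const Shape& in_shape,
    const Strides& in_strides) {
  Strides result;
  size_t j = 0;
  // Leading dimensions the input does not have are broadcast: stride 0,
  // unless the output dimension itself is 1, in which case the output stride is kept.
  for (; j < out_shape.size() - in_shape.size(); ++j) {
    result.push_back(out_shape[j] == 1 ? out_strides[j] : 0);
  }
  // Size-1 input dimensions are also broadcast; real dimensions keep their stride.
  for (size_t i = 0; i < in_shape.size(); ++i, ++j) {
    if (in_shape[i] == 1) {
      result.push_back(out_shape[j] == 1 ? out_strides[j] : 0);
    } else {
      result.push_back(in_strides[i]);
    }
  }
  return result;
}

int main() {
  // Broadcasting a {3} vector against a {2, 3} output: the leading dim is reused.
  auto s = broadcast_strides({2, 3}, {3, 1}, {3}, {1});
  for (auto v : s) {
    std::printf("%lld ", static_cast<long long>(v)); // prints: 0 1
  }
  std::printf("\n");
}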
@@ -1,8 +1,9 @@
 // Copyright © 2023-2024 Apple Inc.
 #pragma once
 
-#include <functional>
 #include <iomanip>
+#include <sstream>
+#include <unordered_set>
 
 #include "mlx/array.h"
 #include "mlx/primitives.h"
@@ -13,17 +14,19 @@ inline bool is_static_cast(const Primitive& p) {
   return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
 }
 
+std::string build_lib_name(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    const std::vector<array>& tape,
+    const std::unordered_set<uintptr_t>& constant_ids);
+
 std::string get_type_string(Dtype d);
 
 template <typename T>
 void print_float_constant(std::ostream& os, const array& x) {
   auto old_precision = os.precision();
-  if constexpr (std::is_same_v<T, double>) {
-    os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
-  } else {
-    os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
-  }
-  os << x.item<T>() << std::setprecision(old_precision);
+  os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
+     << x.item<T>() << std::setprecision(old_precision);
 }
 
 template <typename T>
@@ -57,21 +60,8 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::function<bool(size_t)>& is_constant,
-    bool contiguous,
-    const std::function<allocator::Buffer(size_t)>& mallocfn =
-        allocator::malloc);
-
-// Collapse contiguous dims ignoring scalars and constants.
-std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
-    const std::vector<array>& inputs,
-    const array& out,
-    const std::function<bool(size_t)>& is_constant);
-
-// Return whether the kernel should use large index.
-bool compiled_use_large_index(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
+    const std::vector<array>& inputs_,
+    const std::unordered_set<uintptr_t>& constant_ids_,
     bool contiguous);
 
 } // namespace mlx::core
@@ -2,7 +2,7 @@
 
 #pragma once
 
-#include "mlx/backend/common/utils.h"
+#include "mlx/array.h"
 
 namespace mlx::core {
 
@@ -22,27 +22,23 @@ enum class CopyType {
   GeneralGeneral
 };
 
-inline bool set_copy_output_data(
-    const array& in,
-    array& out,
-    CopyType ctype,
-    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
+inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
   if (ctype == CopyType::Vector) {
     // If the input is donateable, we are doing a vector copy and the types
     // have the same size, then the input buffer can hold the output.
-    if (is_donatable(in, out)) {
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
       out.copy_shared_buffer(in);
       return true;
     } else {
       out.set_data(
-          mallocfn(in.data_size() * out.itemsize()),
+          allocator::malloc(in.data_size() * out.itemsize()),
           in.data_size(),
           in.strides(),
          in.flags());
       return false;
     }
   } else {
-    out.set_data(mallocfn(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     return false;
   }
 }
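
The comment in set_copy_output_data above captures the donation rule: a vector copy can write into the input's own buffer when nothing else holds that buffer and the element sizes match. A toy sketch of that decision, with a hypothetical Buf struct standing in for MLX arrays:

// Sketch only: the Buf struct and can_donate helper are illustrative assumptions.
#include <cstddef>
#include <cstdio>

struct Buf {
  int refcount; // number of arrays sharing this allocation
  size_t itemsize; // bytes per element
};

bool can_donate(const Buf& in, size_t out_itemsize) {
  // Reuse the input allocation only if it is exclusively owned and the
  // output elements are the same width, so an in-place write is safe.
  return in.refcount == 1 && in.itemsize == out_itemsize;
}

int main() {
  Buf a{1, 4};
  Buf b{2, 4};
  std::printf("%d %d\n", can_donate(a, 4), can_donate(b, 4)); // prints: 1 0
}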
@@ -99,10 +99,6 @@ inline std::pair<int, int> decompose_hadamard(int n) {
           "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
     }
   }
-  if (n > (1 << 26)) {
-    throw std::invalid_argument(
-        "[hadamard] Only supports n = m*2^k where k <= 26");
-  }
   return {n, m};
 }
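
The decompose_hadamard hunk above assumes n factors as m * 2^k with m in (1, 12, 20, 28). A self-contained sketch of one way to compute such a factorization; the decompose name and return convention here are illustrative, not the MLX routine:

// Sketch only: returns {2^k, m} such that n == m * 2^k, preferring the largest base size m.
#include <cstdio>
#include <stdexcept>
#include <utility>

bool is_power_of_two(int x) {
  return x > 0 && (x & (x - 1)) == 0;
}

std::pair<int, int> decompose(int n) {
  for (int m : {28, 20, 12, 1}) {
    if (n % m == 0 && is_power_of_two(n / m)) {
      return {n / m, m};
    }
  }
  throw std::invalid_argument("n must be m * 2^k with m in (1, 12, 20, 28)");
}

int main() {
  auto [k2, m] = decompose(24); // 24 = 12 * 2
  std::printf("%d * %d\n", m, k2); // prints: 12 * 2
}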
@@ -1,67 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
|
||||||
#include "mlx/utils.h"
|
|
||||||
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
|
||||||
const array& a,
|
|
||||||
const array& b) {
|
|
||||||
if (a.ndim() == 2) {
|
|
||||||
return {Shape{1}, Strides{0}, Strides{0}};
|
|
||||||
}
|
|
||||||
|
|
||||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
|
||||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
|
||||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
|
||||||
|
|
||||||
auto [batch_shape, batch_strides] =
|
|
||||||
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
|
||||||
|
|
||||||
auto a_batch_strides = batch_strides[0];
|
|
||||||
auto b_batch_strides = batch_strides[1];
|
|
||||||
|
|
||||||
if (batch_shape.empty()) {
|
|
||||||
batch_shape.push_back(1);
|
|
||||||
a_batch_strides.push_back(0);
|
|
||||||
b_batch_strides.push_back(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
|
||||||
collapse_batches(const array& a, const array& b, const array& c) {
|
|
||||||
if (a.ndim() == 2) {
|
|
||||||
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
|
|
||||||
}
|
|
||||||
|
|
||||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
|
||||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
|
||||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
|
||||||
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
|
||||||
|
|
||||||
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
|
||||||
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
|
||||||
|
|
||||||
auto A_batch_stride = batch_strides[0];
|
|
||||||
auto B_batch_stride = batch_strides[1];
|
|
||||||
auto C_batch_stride = batch_strides[2];
|
|
||||||
|
|
||||||
if (batch_shape.empty()) {
|
|
||||||
batch_shape.push_back(1);
|
|
||||||
A_batch_stride.push_back(0);
|
|
||||||
B_batch_stride.push_back(0);
|
|
||||||
C_batch_stride.push_back(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::make_tuple(
|
|
||||||
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -5,9 +5,11 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||||
Shape shape,
|
const array& x,
|
||||||
Strides strides,
|
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
auto shape = x.shape();
|
||||||
|
auto strides = x.strides();
|
||||||
|
|
||||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||||
int a = axes[i];
|
int a = axes[i];
|
||||||
shape.erase(shape.begin() + a);
|
shape.erase(shape.begin() + a);
|
||||||
@@ -17,15 +19,6 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
|||||||
return std::make_pair(shape, strides);
|
return std::make_pair(shape, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
|
||||||
const array& x,
|
|
||||||
const std::vector<int>& axes) {
|
|
||||||
auto shape = x.shape();
|
|
||||||
auto strides = x.strides();
|
|
||||||
return shapes_without_reduction_axes(
|
|
||||||
std::move(shape), std::move(strides), axes);
|
|
||||||
}
|
|
||||||
|
|
||||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||||
// The data is all there and we are reducing over everything
|
// The data is all there and we are reducing over everything
|
||||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||||
|
|||||||
@@ -51,9 +51,5 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
|||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||||
const array& x,
|
const array& x,
|
||||||
const std::vector<int>& axes);
|
const std::vector<int>& axes);
|
||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
|
||||||
Shape shape,
|
|
||||||
Strides strides,
|
|
||||||
const std::vector<int>& axes);
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
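
shapes_without_reduction_axes, touched in the reduce hunks above, simply erases the reduced axes from a shape/stride pair. A small standalone sketch, assuming the axes list is sorted ascending so that walking it from the back keeps earlier indices valid:

// Sketch only: plain std::vector stand-ins for Shape and Strides.
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

std::pair<Shape, Strides>
drop_axes(Shape shape, Strides strides, const std::vector<int>& axes) {
  for (int i = static_cast<int>(axes.size()) - 1; i >= 0; --i) {
    shape.erase(shape.begin() + axes[i]);
    strides.erase(strides.begin() + axes[i]);
  }
  return {std::move(shape), std::move(strides)};
}

// Example: drop_axes({2, 3, 4}, {12, 4, 1}, {1}) yields shape {2, 4}, strides {12, 1}.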
@@ -14,13 +14,17 @@ std::tuple<int64_t, Strides> prepare_slice(
     data_offset += start_indices[i] * in.strides()[i];
     inp_strides[i] = in.strides()[i] * strides[i];
   }
+  // Normalize the offset
+  if (data_offset < 0) {
+    data_offset += in.data_size();
+  }
   return std::make_tuple(data_offset, inp_strides);
 }
 
 void shared_buffer_slice(
     const array& in,
     const Strides& out_strides,
-    int64_t data_offset,
+    size_t data_offset,
     size_t data_size,
     array& out) {
   // Compute row/col contiguity
@@ -41,30 +45,23 @@ void slice(
     const Shape& start_indices,
     const Shape& strides) {
   if (out.size() == 0) {
-    out.set_data(allocator::malloc(0));
+    out.set_data(nullptr);
     return;
   }
 
   // Calculate out strides, initial offset
   auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
-  // Get the location of the end based on the inp strides and out.shape()
-  int64_t low_idx = 0;
-  int64_t high_idx = 0;
-  for (int i = 0; i < inp_strides.size(); ++i) {
-    auto delta = inp_strides[i] * (out.shape()[i] - 1);
-    if (inp_strides[i] > 0) {
-      high_idx += delta;
-    } else {
-      low_idx += delta;
+  int64_t data_end = 1;
+  for (int i = 0; i < start_indices.size(); ++i) {
+    if (in.shape()[i] > 1) {
+      auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
+      data_end += end_idx * in.strides()[i];
     }
   }
-  int64_t data_size = (high_idx - low_idx) + 1;
-  if (data_size < 0) {
-    std::ostringstream msg;
-    msg << "[slice] Computed invalid data size: " << data_size << ".";
-    throw std::runtime_error(msg.str());
+  if (data_end < 0) {
+    data_end += in.data_size();
   }
+  size_t data_size = (data_end - data_offset);
   shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
 }
 
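
The branch removed from slice() above sized the shared buffer by the index span the strided view touches, which also covers negative strides. A standalone sketch of that span computation, with plain vectors standing in for Shape and Strides:

// Sketch only: number of underlying elements spanned by a strided view.
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t view_extent(
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  int64_t low = 0, high = 0;
  for (size_t i = 0; i < shape.size(); ++i) {
    int64_t delta = strides[i] * (shape[i] - 1); // offset of the last element on axis i
    (delta > 0 ? high : low) += delta; // positive strides extend the top, negative the bottom
  }
  return high - low + 1; // inclusive of both end elements
}

int main() {
  // A 3x2 view with a negative stride on the first axis spans 6 elements.
  std::printf("%lld\n", static_cast<long long>(view_extent({3, 2}, {-2, 1})));
}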
@@ -11,8 +11,6 @@ namespace mlx::core {
|
|||||||
enum class TernaryOpType {
|
enum class TernaryOpType {
|
||||||
ScalarScalarScalar,
|
ScalarScalarScalar,
|
||||||
VectorVectorVector,
|
VectorVectorVector,
|
||||||
VectorVectorScalar,
|
|
||||||
VectorScalarVector,
|
|
||||||
General,
|
General,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -27,14 +25,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
|
|||||||
(a.flags().col_contiguous && b.flags().col_contiguous &&
|
(a.flags().col_contiguous && b.flags().col_contiguous &&
|
||||||
c.flags().col_contiguous)) {
|
c.flags().col_contiguous)) {
|
||||||
topt = TernaryOpType::VectorVectorVector;
|
topt = TernaryOpType::VectorVectorVector;
|
||||||
} else if (
|
|
||||||
b.data_size() == 1 && a.flags().row_contiguous &&
|
|
||||||
c.flags().row_contiguous) {
|
|
||||||
topt = TernaryOpType::VectorScalarVector;
|
|
||||||
} else if (
|
|
||||||
c.data_size() == 1 && a.flags().row_contiguous &&
|
|
||||||
b.flags().row_contiguous) {
|
|
||||||
topt = TernaryOpType::VectorVectorScalar;
|
|
||||||
} else {
|
} else {
|
||||||
topt = TernaryOpType::General;
|
topt = TernaryOpType::General;
|
||||||
}
|
}
|
||||||
@@ -46,8 +36,7 @@ inline void set_ternary_op_output_data(
|
|||||||
const array& b,
|
const array& b,
|
||||||
const array& c,
|
const array& c,
|
||||||
array& out,
|
array& out,
|
||||||
TernaryOpType topt,
|
TernaryOpType topt) {
|
||||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
|
||||||
auto maybe_donate = [&out](const array& x) {
|
auto maybe_donate = [&out](const array& x) {
|
||||||
if (is_donatable(x, out)) {
|
if (is_donatable(x, out)) {
|
||||||
out.copy_shared_buffer(x);
|
out.copy_shared_buffer(x);
|
||||||
@@ -58,25 +47,24 @@ inline void set_ternary_op_output_data(
|
|||||||
|
|
||||||
switch (topt) {
|
switch (topt) {
|
||||||
case TernaryOpType::ScalarScalarScalar:
|
case TernaryOpType::ScalarScalarScalar:
|
||||||
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
|
out.set_data(
|
||||||
|
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
|
||||||
break;
|
break;
|
||||||
case TernaryOpType::VectorVectorVector:
|
case TernaryOpType::VectorVectorVector:
|
||||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
mallocfn(out.itemsize() * b.data_size()),
|
allocator::malloc(out.itemsize() * b.data_size()),
|
||||||
b.data_size(),
|
b.data_size(),
|
||||||
b.strides(),
|
b.strides(),
|
||||||
b.flags());
|
b.flags());
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case TernaryOpType::VectorVectorScalar:
|
|
||||||
case TernaryOpType::VectorScalarVector:
|
|
||||||
case TernaryOpType::General:
|
case TernaryOpType::General:
|
||||||
// Try to donate an input which is row_contiguous
|
// Try to donate an input which is row_contiguous
|
||||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||||
out.set_data(mallocfn(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -1,29 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#pragma once
-
-#include "mlx/allocator.h"
-#include "mlx/backend/common/utils.h"
-
-namespace mlx::core {
-
-inline void set_unary_output_data(
-    const array& in,
-    array& out,
-    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
-  if (in.flags().contiguous) {
-    if (is_donatable(in, out)) {
-      out.copy_shared_buffer(in);
-    } else {
-      out.set_data(
-          mallocfn(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(mallocfn(out.nbytes()));
-  }
-}
-
-} // namespace mlx::core
@@ -1,22 +1,9 @@
 // Copyright © 2023-2024 Apple Inc.
 
-#include <dlfcn.h>
-
 #include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
 
-std::filesystem::path current_binary_dir() {
-  static std::filesystem::path binary_dir = []() {
-    Dl_info info;
-    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
-      throw std::runtime_error("Unable to get current binary dir.");
-    }
-    return std::filesystem::path(info.dli_fname).parent_path();
-  }();
-  return binary_dir;
-}
-
 std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
     const Shape& shape,
     const std::vector<Strides>& strides,
@@ -114,118 +101,4 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
|||||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||||
}
|
}
|
||||||
|
|
||||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
|
||||||
int pows[3] = {0, 0, 0};
|
|
||||||
int sum = 0;
|
|
||||||
while (true) {
|
|
||||||
int presum = sum;
|
|
||||||
// Check all the pows
|
|
||||||
if (dim0 >= (1 << (pows[0] + 1))) {
|
|
||||||
pows[0]++;
|
|
||||||
sum++;
|
|
||||||
}
|
|
||||||
if (sum == 10) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (dim1 >= (1 << (pows[1] + 1))) {
|
|
||||||
pows[1]++;
|
|
||||||
sum++;
|
|
||||||
}
|
|
||||||
if (sum == 10) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (dim2 >= (1 << (pows[2] + 1))) {
|
|
||||||
pows[2]++;
|
|
||||||
sum++;
|
|
||||||
}
|
|
||||||
if (sum == presum || sum == pow2) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
|
|
||||||
// Dims with strides of 0 are ignored as they
|
|
||||||
// correspond to broadcasted dimensions
|
|
||||||
size_t grid_x = 1;
|
|
||||||
size_t grid_y = 1;
|
|
||||||
for (int i = 0; i < shape.size(); ++i) {
|
|
||||||
if (strides[i] == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (grid_x * shape[i] < UINT32_MAX) {
|
|
||||||
grid_x *= shape[i];
|
|
||||||
} else {
|
|
||||||
grid_y *= shape[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
|
||||||
throw std::runtime_error("Unable to safely factor shape.");
|
|
||||||
}
|
|
||||||
if (grid_y > grid_x) {
|
|
||||||
std::swap(grid_x, grid_y);
|
|
||||||
}
|
|
||||||
return std::make_tuple(
|
|
||||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
Dims get_2d_grid_dims_common(
|
|
||||||
const Shape& shape,
|
|
||||||
const Strides& strides,
|
|
||||||
size_t divisor) {
|
|
||||||
// Compute the 2d grid dimensions such that the total size of the grid is
|
|
||||||
// divided by divisor.
|
|
||||||
size_t grid_x = 1;
|
|
||||||
size_t grid_y = 1;
|
|
||||||
for (int i = 0; i < shape.size(); ++i) {
|
|
||||||
if (strides[i] == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// No need to add this shape we can just remove it from the divisor.
|
|
||||||
if (divisor % shape[i] == 0) {
|
|
||||||
divisor /= shape[i];
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (grid_x * shape[i] < UINT32_MAX) {
|
|
||||||
grid_x *= shape[i];
|
|
||||||
} else {
|
|
||||||
grid_y *= shape[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (divisor > 1) {
|
|
||||||
if (grid_x % divisor == 0) {
|
|
||||||
grid_x /= divisor;
|
|
||||||
divisor = 1;
|
|
||||||
} else if (grid_y % divisor == 0) {
|
|
||||||
grid_y /= divisor;
|
|
||||||
divisor = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
|
||||||
throw std::runtime_error("Unable to safely factor shape.");
|
|
||||||
}
|
|
||||||
if (grid_y > grid_x) {
|
|
||||||
std::swap(grid_x, grid_y);
|
|
||||||
}
|
|
||||||
if (divisor > 1) {
|
|
||||||
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
|
|
||||||
}
|
|
||||||
return std::make_tuple(
|
|
||||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
|
||||||
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
|
|
||||||
auto gx = (dim0 + bx - 1) / bx;
|
|
||||||
auto gy = (dim1 + by - 1) / by;
|
|
||||||
auto gz = (dim2 + bz - 1) / bz;
|
|
||||||
|
|
||||||
return std::make_pair(
|
|
||||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
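
The get_2d_grid_dims_common removed above factors an array's element count into a 2D launch grid with each dimension below UINT32_MAX, skipping broadcast (stride-0) dimensions. A simplified standalone sketch of that factoring; the names and the divisor-free variant are assumptions for illustration:

// Sketch only: factor the non-broadcast element count into (x, y) grid dimensions.
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <utility>
#include <vector>

std::pair<uint32_t, uint32_t> grid_2d(
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  uint64_t gx = 1, gy = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue; // broadcast dimension, contributes no distinct elements
    }
    // Grow the x dimension until it would exceed 32 bits, then spill into y.
    if (gx * shape[i] < UINT32_MAX) {
      gx *= shape[i];
    } else {
      gy *= shape[i];
    }
  }
  if (gx > UINT32_MAX || gy > UINT32_MAX) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  if (gy > gx) {
    std::swap(gx, gy);
  }
  return {static_cast<uint32_t>(gx), static_cast<uint32_t>(gy)};
}

int main() {
  auto [x, y] = grid_2d({1 << 20, 1 << 16}, {1 << 16, 1});
  std::printf("%u x %u\n", x, y); // the product covers all 2^36 elements
}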
|
|||||||
@@ -2,17 +2,12 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <filesystem>
|
|
||||||
#include <tuple>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
// Return the directory that contains current shared library.
|
|
||||||
std::filesystem::path current_binary_dir();
|
|
||||||
|
|
||||||
inline int64_t
|
inline int64_t
|
||||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||||
int64_t loc = 0;
|
int64_t loc = 0;
|
||||||
@@ -75,31 +70,6 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
|||||||
const array& a,
|
const array& a,
|
||||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||||
|
|
||||||
// Compute the thread block dimensions which fit the given
|
|
||||||
// input dimensions.
|
|
||||||
// - The thread block dimensions will be powers of two
|
|
||||||
// - The thread block size will be less than 2^pow2
|
|
||||||
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
|
|
||||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
|
|
||||||
|
|
||||||
// Computes a 2D grid where each element is < UINT_MAX
|
|
||||||
// Assumes:
|
|
||||||
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
|
||||||
// - shape and strides correspond to a contiguous (no holes) but
|
|
||||||
// possibly broadcasted array
|
|
||||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
|
|
||||||
|
|
||||||
// Same as above but we do an implicit division with divisor.
|
|
||||||
// Basically, equivalent to factorizing
|
|
||||||
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
|
||||||
Dims get_2d_grid_dims_common(
|
|
||||||
const Shape& shape,
|
|
||||||
const Strides& strides,
|
|
||||||
size_t divisor);
|
|
||||||
|
|
||||||
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
|
|
||||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
|
|
||||||
|
|
||||||
struct ContiguousIterator {
|
struct ContiguousIterator {
|
||||||
inline void step() {
|
inline void step() {
|
||||||
int dims = shape_.size();
|
int dims = shape_.size();
|
||||||
@@ -195,11 +165,4 @@ void shared_buffer_reshape(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const Strides& out_strides,
|
const Strides& out_strides,
|
||||||
array& out);
|
array& out);
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
|
||||||
vec.erase(std::next(vec.begin(), index));
|
|
||||||
return vec;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -40,13 +40,11 @@ add_dependencies(mlx cpu_compiled_preamble)
|
|||||||
|
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
|
|||||||
@@ -14,8 +14,10 @@ template <typename InT, typename OpT>
|
|||||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||||
auto axis_size = in.shape()[axis];
|
auto axis_size = in.shape()[axis];
|
||||||
auto axis_stride = in.strides()[axis];
|
auto axis_stride = in.strides()[axis];
|
||||||
Strides strides = remove_index(in.strides(), axis);
|
Strides strides = in.strides();
|
||||||
Shape shape = remove_index(in.shape(), axis);
|
Shape shape = in.shape();
|
||||||
|
strides.erase(strides.begin() + axis);
|
||||||
|
shape.erase(shape.begin() + axis);
|
||||||
auto in_ptr = in.data<InT>();
|
auto in_ptr = in.data<InT>();
|
||||||
auto out_ptr = out.data<uint32_t>();
|
auto out_ptr = out.data<uint32_t>();
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/cpu/available.h"
|
|
||||||
|
|
||||||
namespace mlx::core::cpu {
|
|
||||||
|
|
||||||
bool is_available() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cpu
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
namespace mlx::core::cpu {
|
|
||||||
|
|
||||||
bool is_available();
|
|
||||||
|
|
||||||
} // namespace mlx::core::cpu
|
|
||||||
@@ -14,11 +14,230 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
binary_op<double, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void comparison_op(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_float(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
binary_op<double, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[binary_float] Only supports non-complex floating point types.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_int(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool, Op>(a, b, out, bopt);
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("[binary_int] Type not supported");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_op_cpu(a, b, out, detail::Add(), stream());
|
binary(a, b, out, detail::Add(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void DivMod::eval_cpu(
|
void DivMod::eval_cpu(
|
||||||
@@ -102,14 +321,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_op_cpu(a, b, out, detail::Divide(), stream());
|
binary(a, b, out, detail::Divide(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_op_cpu(a, b, out, detail::Remainder(), stream());
|
binary(a, b, out, detail::Remainder(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -150,90 +369,89 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
comparison_op_cpu(a, b, out, detail::Equal(), stream());
|
comparison_op(a, b, out, detail::Equal(), stream());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
|
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op_cpu(
|
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||||
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
|
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
|
binary_float(a, b, out, detail::LogAddExp(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||||
auto& in1 = inputs[0];
|
auto& in1 = inputs[0];
|
||||||
auto& in2 = inputs[1];
|
auto& in2 = inputs[1];
|
||||||
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
|
binary(in1, in2, out, detail::LogicalAnd(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||||
auto& in1 = inputs[0];
|
auto& in1 = inputs[0];
|
||||||
auto& in2 = inputs[1];
|
auto& in2 = inputs[1];
|
||||||
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
|
binary(in1, in2, out, detail::LogicalOr(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_op_cpu(a, b, out, detail::Maximum(), stream());
|
binary(a, b, out, detail::Maximum(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_op_cpu(a, b, out, detail::Minimum(), stream());
|
binary(a, b, out, detail::Minimum(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_op_cpu(a, b, out, detail::Multiply(), stream());
|
binary(a, b, out, detail::Multiply(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_op_cpu(a, b, out, detail::Power(), stream());
|
binary(a, b, out, detail::Power(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary_op_cpu(a, b, out, detail::Subtract(), stream());
|
binary(a, b, out, detail::Subtract(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -242,19 +460,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
switch (op_) {
|
switch (op_) {
|
||||||
case BitwiseBinary::And:
|
case BitwiseBinary::And:
|
||||||
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
|
binary_int(a, b, out, detail::BitwiseAnd(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Or:
|
case BitwiseBinary::Or:
|
||||||
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
|
binary_int(a, b, out, detail::BitwiseOr(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Xor:
|
case BitwiseBinary::Xor:
|
||||||
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
|
binary_int(a, b, out, detail::BitwiseXor(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::LeftShift:
|
case BitwiseBinary::LeftShift:
|
||||||
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
|
binary_int(a, b, out, detail::LeftShift(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::RightShift:
|
case BitwiseBinary::RightShift:
|
||||||
binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
|
binary_int(a, b, out, detail::RightShift(), stream());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -263,7 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
const auto& b = inputs[1];
|
const auto& b = inputs[1];
|
||||||
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
|
binary_float(a, b, out, detail::ArcTan2(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
#include "mlx/backend/cpu/encoder.h"
|
|
||||||
#include "mlx/backend/cpu/simd/simd.h"
|
#include "mlx/backend/cpu/simd/simd.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -291,227 +290,4 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
|||||||
binary_op<T, T, Op>(a, b, out, bopt);
|
binary_op<T, T, Op>(a, b, out, bopt);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary_op_cpu(
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
array& out,
|
|
||||||
Op op,
|
|
||||||
Stream stream) {
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
||||||
b = array::unsafe_weak_copy(b),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void comparison_op_cpu(
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
array& out,
|
|
||||||
Op op,
|
|
||||||
Stream stream) {
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
||||||
b = array::unsafe_weak_copy(b),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (a.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary_float_op_cpu(
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
array& out,
|
|
||||||
Op op,
|
|
||||||
Stream stream) {
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
||||||
b = array::unsafe_weak_copy(b),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[binary_float] Only supports floating point types.");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary_int_op_cpu(
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
array& out,
|
|
||||||
Op op,
|
|
||||||
Stream stream) {
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
||||||
b = array::unsafe_weak_copy(b),
|
|
||||||
out = array::unsafe_weak_copy(out),
|
|
||||||
bopt]() mutable {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool, Op>(a, b, out, bopt);
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error("[binary_int] Type not supported");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
 
   // The decomposition is computed in place, so just copy the input to the
   // output.
-  copy_cpu(
+  copy(
       a,
       factor,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
#include "mlx/backend/cpu/jit_compiler.h"
|
#include "mlx/backend/cpu/jit_compiler.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/graph_utils.h"
|
#include "mlx/graph_utils.h"
|
||||||
#include "mlx/version.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -41,10 +40,7 @@ struct CompilerCache {
   std::shared_mutex mtx;
 };
 
-static CompilerCache& cache() {
-  static CompilerCache cache_;
-  return cache_;
-};
+static CompilerCache cache{};
 
 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
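
The cache() accessor on the left of the hunk above is the classic function-local static singleton, paired later in the file with a shared_mutex read path and an exclusive write path. A self-contained sketch of that lookup pattern, with a plain string-to-pointer map standing in for the real CompilerCache:

// Sketch only: construct-on-first-use cache with a read-mostly lookup.
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct Cache {
  std::unordered_map<std::string, void*> kernels;
  std::shared_mutex mtx;
};

Cache& cache() {
  static Cache cache_; // constructed on first use, avoids static init order issues
  return cache_;
}

void* lookup_or_build(const std::string& name, void* (*build)()) {
  {
    std::shared_lock lock(cache().mtx); // many concurrent readers
    if (auto it = cache().kernels.find(name); it != cache().kernels.end()) {
      return it->second;
    }
  }
  std::unique_lock lock(cache().mtx); // exclusive writer
  if (auto it = cache().kernels.find(name); it != cache().kernels.end()) {
    return it->second; // another thread built it while we waited
  }
  void* fn = build();
  cache().kernels.insert({name, fn});
  return fn;
}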
@@ -60,16 +56,14 @@ void* compile(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const std::function<std::string(void)>& source_builder) {
|
const std::function<std::string(void)>& source_builder) {
|
||||||
{
|
{
|
||||||
std::shared_lock lock(cache().mtx);
|
std::shared_lock lock(cache.mtx);
|
||||||
if (auto it = cache().kernels.find(kernel_name);
|
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||||
it != cache().kernels.end()) {
|
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_lock lock(cache().mtx);
|
std::unique_lock lock(cache.mtx);
|
||||||
if (auto it = cache().kernels.find(kernel_name);
|
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||||
it != cache().kernels.end()) {
|
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
std::string source_code = source_builder();
|
std::string source_code = source_builder();
|
||||||
@@ -95,11 +89,7 @@ void* compile(
|
|||||||
kernel_file_name = kernel_name;
|
kernel_file_name = kernel_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto output_dir =
|
auto output_dir = std::filesystem::temp_directory_path();
|
||||||
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
|
|
||||||
if (!std::filesystem::exists(output_dir)) {
|
|
||||||
std::filesystem::create_directories(output_dir);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
|
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
|
||||||
auto shared_lib_path = (output_dir / shared_lib_name).string();
|
auto shared_lib_path = (output_dir / shared_lib_name).string();
|
||||||
@@ -130,10 +120,10 @@ void* compile(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// load library
|
// load library
|
||||||
cache().libs.emplace_back(shared_lib_path);
|
cache.libs.emplace_back(shared_lib_path);
|
||||||
|
|
||||||
// Load function
|
// Load function
|
||||||
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
|
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
||||||
if (!fun) {
|
if (!fun) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||||
@@ -141,7 +131,7 @@ void* compile(
|
|||||||
<< dlerror();
|
<< dlerror();
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
cache().kernels.insert({kernel_name, fun});
|
cache.kernels.insert({kernel_name, fun});
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,9 +141,18 @@ inline void build_kernel(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<array>& outputs,
|
const std::vector<array>& outputs,
|
||||||
const std::vector<array>& tape,
|
const std::vector<array>& tape,
|
||||||
const std::function<bool(size_t)>& is_constant,
|
const std::unordered_set<uintptr_t>& constant_ids,
|
||||||
bool contiguous,
|
bool contiguous,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
|
// All outputs should have the exact same shape and will be row contiguous
|
||||||
|
auto output_shape = outputs[0].shape();
|
||||||
|
auto output_strides = outputs[0].strides();
|
||||||
|
|
||||||
|
// Constants are scalars that are captured by value and cannot change
|
||||||
|
auto is_constant = [&constant_ids](const array& x) {
|
||||||
|
return constant_ids.find(x.id()) != constant_ids.end();
|
||||||
|
};
|
||||||
|
|
||||||
NodeNamer namer;
|
NodeNamer namer;
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
@@ -162,28 +161,25 @@ inline void build_kernel(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Start the kernel
|
// Start the kernel
|
||||||
os << "void " << kernel_name
|
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||||
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
|
|
||||||
|
|
||||||
// Add the input arguments
|
// Add the input arguments
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
int strides_index = 1;
|
for (auto& x : inputs) {
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
auto& xname = namer.get_name(x);
|
||||||
|
|
||||||
// Skip constants from the input list
|
// Skip constants from the input list
|
||||||
if (is_constant(i)) {
|
if (is_constant(x)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& x = inputs[i];
|
|
||||||
auto& xname = namer.get_name(x);
|
|
||||||
|
|
||||||
auto tstr = get_type_string(x.dtype());
|
auto tstr = get_type_string(x.dtype());
|
||||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||||
<< "];" << std::endl;
|
<< "];" << std::endl;
|
||||||
// Scalars and contiguous need no strides
|
// Scalars and contiguous need no strides
|
||||||
if (!is_scalar(x) && !contiguous) {
|
if (!is_scalar(x) && !contiguous) {
|
||||||
os << " const int64_t* " << xname << "_strides = strides["
|
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
||||||
<< strides_index++ << "];" << std::endl;
|
<< "];" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,8 +189,10 @@ inline void build_kernel(
|
|||||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||||
}
|
}
|
||||||
// Add output size
|
// Add output strides and shape to extract the indices.
|
||||||
if (contiguous) {
|
if (!contiguous) {
|
||||||
|
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
||||||
|
} else {
|
||||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,11 +206,10 @@ inline void build_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read the inputs in tmps
|
// Read the inputs in tmps
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (auto& x : inputs) {
|
||||||
const auto& x = inputs[i];
|
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
|
|
||||||
if (is_constant(i)) {
|
if (is_constant(x)) {
|
||||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||||
print_constant(os, x);
|
print_constant(os, x);
|
||||||
os << ";" << std::endl;
|
os << ";" << std::endl;
|
||||||
@@ -236,7 +233,7 @@ inline void build_kernel(
|
|||||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||||
} else {
|
} else {
|
||||||
os << x.primitive().name();
|
x.primitive().print(os);
|
||||||
os << "()(";
|
os << "()(";
|
||||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||||
@@ -262,9 +259,8 @@ inline void build_kernel(
|
|||||||
} else {
|
} else {
|
||||||
for (int d = ndim - 1; d >= 0; --d) {
|
for (int d = ndim - 1; d >= 0; --d) {
|
||||||
// Update pointers
|
// Update pointers
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (auto& x : inputs) {
|
||||||
const auto& x = inputs[i];
|
if (is_constant(x) || is_scalar(x)) {
|
||||||
if (is_constant(i) || is_scalar(x)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
@@ -286,33 +282,65 @@ inline void build_kernel(
|
|||||||
void Compiled::eval_cpu(
|
void Compiled::eval_cpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
|
if (kernel_lib_.empty()) {
|
||||||
|
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out which kernel we are using
|
||||||
|
auto& shape = outputs[0].shape();
|
||||||
|
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
|
||||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
// Handle all broadcasting and collect function input arguments
|
||||||
// handle all broadcasting.
|
|
||||||
auto [contiguous, shape, strides] =
|
|
||||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
|
||||||
|
|
||||||
// Collect function input arguments.
|
|
||||||
std::vector<void*> args;
|
std::vector<void*> args;
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
std::vector<std::vector<size_t>> strides;
|
||||||
if (is_constant_(i)) {
|
for (int i = 0; i < inputs.size(); i++) {
|
||||||
|
// Skip constants.
|
||||||
|
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const auto& x = inputs[i];
|
auto& x = inputs[i];
|
||||||
encoder.set_input_array(x);
|
encoder.set_input_array(x);
|
||||||
args.push_back((void*)x.data<void>());
|
args.push_back((void*)x.data<void>());
|
||||||
|
|
||||||
|
if (contiguous || is_scalar(x)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast the input to the output shape.
|
||||||
|
std::vector<size_t> xstrides;
|
||||||
|
int j = 0;
|
||||||
|
for (; j < shape.size() - x.ndim(); j++) {
|
||||||
|
if (shape[j] == 1) {
|
||||||
|
xstrides.push_back(outputs[0].strides()[j]);
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||||
|
if (x.shape(i) == 1) {
|
||||||
|
if (shape[j] == 1) {
|
||||||
|
xstrides.push_back(outputs[0].strides()[j]);
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(x.strides()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
strides.push_back(std::move(xstrides));
|
||||||
|
args.push_back(strides.back().data());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the kernel name from the lib
|
// Get the kernel name from the lib
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
kernel_name += std::to_string(ndim);
|
kernel_name += std::to_string(shape.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the function
|
// Get the function
|
||||||
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
|
auto fn_ptr = compile(kernel_name, [&]() {
|
||||||
std::ostringstream kernel;
|
std::ostringstream kernel;
|
||||||
kernel << get_kernel_preamble() << std::endl;
|
kernel << get_kernel_preamble() << std::endl;
|
||||||
kernel << "extern \"C\" {" << std::endl;
|
kernel << "extern \"C\" {" << std::endl;
|
||||||
@@ -322,7 +350,7 @@ void Compiled::eval_cpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
is_constant_,
|
constant_ids_,
|
||||||
contiguous,
|
contiguous,
|
||||||
ndim);
|
ndim);
|
||||||
// Close extern "C"
|
// Close extern "C"
|
||||||
@@ -330,26 +358,26 @@ void Compiled::eval_cpu(
|
|||||||
return kernel.str();
|
return kernel.str();
|
||||||
});
|
});
|
||||||
|
|
||||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
compiled_allocate_outputs(
|
||||||
|
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||||
|
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
args.push_back(x.data<void>());
|
args.push_back(x.data<void>());
|
||||||
encoder.set_output_array(x);
|
encoder.set_output_array(x);
|
||||||
}
|
}
|
||||||
if (contiguous) {
|
Shape out_shape;
|
||||||
|
if (!contiguous) {
|
||||||
|
out_shape = outputs[0].shape();
|
||||||
|
args.push_back((void*)out_shape.data());
|
||||||
|
} else {
|
||||||
args.push_back((void*)outputs[0].data_size());
|
args.push_back((void*)outputs[0].data_size());
|
||||||
}
|
}
|
||||||
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
|
auto fun = (void (*)(void**))fn_ptr;
|
||||||
encoder.dispatch([fun,
|
encoder.dispatch(
|
||||||
|
[fun,
|
||||||
args = std::move(args),
|
args = std::move(args),
|
||||||
strides = std::move(strides),
|
strides = std::move(strides),
|
||||||
shape = std::move(shape)]() mutable {
|
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
|
||||||
SmallVector<int64_t*> strides_ptrs;
|
|
||||||
for (auto& s : strides) {
|
|
||||||
strides_ptrs.push_back(s.data());
|
|
||||||
}
|
|
||||||
fun(shape.data(), strides_ptrs.data(), args.data());
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
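The compile() hunks above revolve around a kernel cache that is read under a shared lock and written under an exclusive lock. Here is a minimal stand-alone sketch of that double-checked lookup pattern, assuming C++17; KernelCache and lookup_or_build are hypothetical names, not the repository's API.

// Illustrative sketch only: shared-lock fast path, unique-lock slow path.
#include <functional>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct KernelCache {
  std::unordered_map<std::string, void*> kernels;
  std::shared_mutex mtx;
};

void* lookup_or_build(
    KernelCache& cache,
    const std::string& name,
    const std::function<void*()>& build) {
  {
    // Fast path: many readers may hold the shared lock concurrently.
    std::shared_lock lock(cache.mtx);
    if (auto it = cache.kernels.find(name); it != cache.kernels.end()) {
      return it->second;
    }
  }
  // Slow path: take the exclusive lock and re-check before building, since
  // another thread may have inserted the kernel in the meantime.
  std::unique_lock lock(cache.mtx);
  if (auto it = cache.kernels.find(name); it != cache.kernels.end()) {
    return it->second;
  }
  void* fun = build();
  cache.kernels.insert({name, fun});
  return fun;
}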
@@ -22,8 +22,7 @@ void slow_conv_1D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -61,8 +60,7 @@ void slow_conv_1D(
       out_stride_O = out.strides()[2],

       flip,
-      padding_lo = padding_lo[0],
-      padding_hi = padding_hi[0],
+      padding = padding[0],
       wt_stride = wt_strides[0],
       wt_dilation = wt_dilation[0],
       in_dilation = in_dilation[0]]() mutable {
@@ -79,7 +77,7 @@ void slow_conv_1D(
         const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;

         int wh_flip = flip ? (wH - wh - 1) : wh;
-        int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
+        int ih = oh * wt_stride - padding + wh_flip * wt_dilation;

         auto ih_div = std::div(ih, in_dilation);

@@ -111,8 +109,7 @@ void slow_conv_2D(
     const array& in,
     const array& wt,
     array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
@@ -123,14 +120,16 @@ void slow_conv_2D(
   encoder.set_input_array(wt);
   encoder.set_output_array(out);

-  encoder.dispatch(
-      [st_wt_ptr = wt.data<T>(),
+  encoder.dispatch([st_wt_ptr = wt.data<T>(),
        st_in_ptr = in.data<T>(),
        st_out_ptr = out.data<T>(),

-       N = in.shape(0), // Batch size, should be the same as out.shape(0)
-       iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
-       iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
+       N = in.shape(
+           0), // Batch size, should be the same as out.shape(0)
+       iH = 1 +
+           in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
+       iW = 1 +
+           in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
        C = in.shape(3), // In channels
        oH = out.shape(1), // Output spatial dim
        oW = out.shape(2), // Output spatial dim
@@ -156,8 +155,7 @@ void slow_conv_2D(
        out_stride_W = out.strides()[2],
        out_stride_O = out.strides()[3],

-       padding_lo,
-       padding_hi,
+       padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -165,11 +163,14 @@ void slow_conv_2D(
    bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;

    const int O_per_group = O / groups;
-    auto pt_conv_no_checks =
-        [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
+    auto pt_conv_no_checks = [&](const T* in_ptr,
+                                 const T* wt_ptr,
+                                 T* out_ptr,
+                                 int oh,
+                                 int ow) {
      out_ptr += oh * out_stride_H + ow * out_stride_W;
-      int ih_base = oh * wt_strides[0] - padding_lo[0];
-      int iw_base = ow * wt_strides[1] - padding_lo[1];
+      int ih_base = oh * wt_strides[0] - padding[0];
+      int iw_base = ow * wt_strides[1] - padding[1];

      for (int g = 0; g < groups; ++g) {
        for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
@@ -182,13 +183,10 @@ void slow_conv_2D(
              int ih = ih_base + wh_flip * wt_dilation[0];
              int iw = iw_base + ww_flip * wt_dilation[1];

-              const T* wt_ptr_pt =
-                  wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
-              const T* in_ptr_pt =
-                  in_ptr + ih * in_stride_H + iw * in_stride_W;
+              const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
+              const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;

-              for (int c = g * C_per_group; c < (g + 1) * C_per_group;
-                   ++c) {
+              for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
                r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                    static_cast<float>(
                        wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
@@ -214,16 +212,14 @@ void slow_conv_2D(
    int f_wgt_jump_w =
        std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];

-    int f_out_jump_h =
-        std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
-    int f_out_jump_w =
-        std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
+    int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
+    int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];

    std::vector<int> base_h(f_out_jump_h);
    std::vector<int> base_w(f_out_jump_w);

    for (int i = 0; i < f_out_jump_h; ++i) {
-      int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
+      int ih_loop = i * wt_strides[0] - padding[0] + init_h;

      int wh_base = 0;
      while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
@@ -235,7 +231,7 @@ void slow_conv_2D(
    }

    for (int j = 0; j < f_out_jump_w; ++j) {
-      int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
+      int iw_loop = j * wt_strides[1] - padding[1] + init_w;

      int ww_base = 0;
      while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
@@ -250,8 +246,8 @@ void slow_conv_2D(
        [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
          out_ptr += oh * out_stride_H + ow * out_stride_W;

-          int ih_base = oh * wt_strides[0] - padding_lo[0];
-          int iw_base = ow * wt_strides[1] - padding_lo[1];
+          int ih_base = oh * wt_strides[0] - padding[0];
+          int iw_base = ow * wt_strides[1] - padding[1];

          int wh_base = base_h[oh % f_out_jump_h];
          int ww_base = base_w[ow % f_out_jump_w];
@@ -274,8 +270,8 @@ void slow_conv_2D(
            int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
            int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;

-            const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
-                iw_dil * in_stride_W;
+            const T* in_ptr_pt =
+                in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;

            for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                 ++c) {
@@ -296,21 +292,17 @@ void slow_conv_2D(
    };

    int oH_border_0 = 0;
-    int oH_border_1 = is_idil_one
-        ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
-        : oH;
+    int oH_border_1 =
+        is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
    int oH_border_2 = std::max(
-        oH_border_1,
-        (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
+        oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
    int oH_border_3 = oH;

    int oW_border_0 = 0;
-    int oW_border_1 = is_idil_one
-        ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
-        : oW;
+    int oW_border_1 =
+        is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
    int oW_border_2 = std::max(
-        oW_border_1,
-        (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
+        oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
    int oW_border_3 = oW;

    for (int n = 0; n < N; ++n) {
@@ -359,8 +351,7 @@ void slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -409,8 +400,7 @@ void slow_conv_3D(
       out_stride_H = out.strides()[2],
       out_stride_W = out.strides()[3],
       out_stride_O = out.strides()[4],
-       padding_lo,
-       padding_hi,
+       padding,
       wt_strides,
       wt_dilation,
       in_dilation,
@@ -425,9 +415,9 @@ void slow_conv_3D(
        int oh,
        int ow) {
      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
-      int id_base = od * wt_strides[0] - padding_lo[0];
-      int ih_base = oh * wt_strides[1] - padding_lo[1];
-      int iw_base = ow * wt_strides[2] - padding_lo[2];
+      int id_base = od * wt_strides[0] - padding[0];
+      int ih_base = oh * wt_strides[1] - padding[1];
+      int iw_base = ow * wt_strides[2] - padding[2];

      for (int o = 0; o < O; ++o) {
        float r = 0.;
@@ -488,7 +478,7 @@ void slow_conv_3D(
    std::vector<int> base_w(f_out_jump_w);

    for (int i = 0; i < f_out_jump_d; ++i) {
-      int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
+      int id_loop = i * wt_strides[0] - padding[0] + init_d;

      int wd_base = 0;
      while (wd_base < wD && id_loop % in_dilation[0] != 0) {
@@ -500,7 +490,7 @@ void slow_conv_3D(
    }

    for (int i = 0; i < f_out_jump_h; ++i) {
-      int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
+      int ih_loop = i * wt_strides[1] - padding[1] + init_h;

      int wh_base = 0;
      while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
@@ -512,7 +502,7 @@ void slow_conv_3D(
    }

    for (int j = 0; j < f_out_jump_w; ++j) {
-      int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
+      int iw_loop = j * wt_strides[2] - padding[2] + init_w;

      int ww_base = 0;
      while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
@@ -531,9 +521,9 @@ void slow_conv_3D(
        int ow) {
          out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;

-          int id_base = od * wt_strides[0] - padding_lo[0];
-          int ih_base = oh * wt_strides[1] - padding_lo[1];
-          int iw_base = ow * wt_strides[2] - padding_lo[2];
+          int id_base = od * wt_strides[0] - padding[0];
+          int ih_base = oh * wt_strides[1] - padding[1];
+          int iw_base = ow * wt_strides[2] - padding[2];

          int wd_base = base_d[od % f_out_jump_d];
          int wh_base = base_h[oh % f_out_jump_h];
@@ -583,30 +573,24 @@ void slow_conv_3D(
    };

    int oD_border_0 = 0;
-    int oD_border_1 = is_idil_one
-        ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
-        : oD;
+    int oD_border_1 =
+        is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
    int oD_border_2 = std::max(
-        oD_border_1,
-        (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
+        oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
    int oD_border_3 = oD;

    int oH_border_0 = 0;
-    int oH_border_1 = is_idil_one
-        ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
-        : oH;
+    int oH_border_1 =
+        is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
    int oH_border_2 = std::max(
-        oH_border_1,
-        (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
+        oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
    int oH_border_3 = oH;

    int oW_border_0 = 0;
-    int oW_border_1 = is_idil_one
-        ? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
-        : oW;
+    int oW_border_1 =
+        is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
    int oW_border_2 = std::max(
-        oW_border_1,
-        (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
+        oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
    int oW_border_3 = oW;

    for (int n = 0; n < N; ++n) {
@@ -674,8 +658,7 @@ void dispatch_slow_conv_1D(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -686,8 +669,7 @@ void dispatch_slow_conv_1D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -698,8 +680,7 @@ void dispatch_slow_conv_1D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -710,8 +691,7 @@ void dispatch_slow_conv_1D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -727,8 +707,7 @@ void dispatch_slow_conv_2D(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -739,8 +718,7 @@ void dispatch_slow_conv_2D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -751,8 +729,7 @@ void dispatch_slow_conv_2D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -763,8 +740,7 @@ void dispatch_slow_conv_2D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -780,8 +756,7 @@ void dispatch_slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -792,8 +767,7 @@ void dispatch_slow_conv_3D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -804,8 +778,7 @@ void dispatch_slow_conv_3D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -816,8 +789,7 @@ void dispatch_slow_conv_3D(
        in,
        wt,
        out,
-        padding_lo,
-        padding_hi,
+        padding,
        wt_strides,
        wt_dilation,
        in_dilation,
@@ -857,8 +829,7 @@ void explicit_gemm_conv_1D_cpu(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    Stream stream) {
@@ -877,16 +848,16 @@ void explicit_gemm_conv_1D_cpu(
  auto& encoder = cpu::get_command_encoder(stream);

  // Pad input
-  Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], C};
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  std::vector<array> temps;
  temps.push_back(array(0, conv_dtype));
-  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
+  copy(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded
-  size_t data_offset = padding_lo[0] * in_padded.strides()[1];
+  size_t data_offset = padding[0] * in_padded.strides()[1];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
@@ -895,7 +866,7 @@ void explicit_gemm_conv_1D_cpu(
      in_padded_slice.size(),
      data_offset);
  // Copy input values into the slice
-  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
+  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
  temps.push_back(in_padded_slice);

  // Make strided view
@@ -920,7 +891,7 @@ void explicit_gemm_conv_1D_cpu(
  // Materialize strided view
  Shape strided_reshape = {N * oH, wH * C};
  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
-  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
+  copy(in_strided_view, in_strided, CopyType::General, stream);
  temps.push_back(in_strided);

  // Check wt dtype and prepare
@@ -938,13 +909,13 @@ void explicit_gemm_conv_1D_cpu(
        wt.size(),
        0);
    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
-    copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
+    copy(wt_transpose, gemm_wt, CopyType::General, stream);
    temps.push_back(gemm_wt);
  } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
-    copy_cpu(wt, gemm_wt, ctype, stream);
+    copy(wt, gemm_wt, ctype, stream);
    temps.push_back(gemm_wt);
  }

@@ -991,7 +962,127 @@ void explicit_gemm_conv_1D_cpu(

  // Copy results if needed
  if (out.dtype() != float32) {
-    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
+    copy_inplace(gemm_out, out, CopyType::Vector, stream);
+  }
+  encoder.add_temporaries(std::move(temps));
+}
+
+void explicit_gemm_conv_2D_cpu(
+    const array& in,
+    const array& wt,
+    array out,
+    const std::vector<int>& padding,
+    const std::vector<int>& wt_strides,
+    const std::vector<int>& wt_dilation,
+    Stream stream) {
+  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
+  const int iH = in.shape(1); // Input spatial dim
+  const int iW = in.shape(2); // Input spatial dim
+  const int oH = out.shape(1); // Output spatial dim
+  const int oW = out.shape(2); // Output spatial dim
+  const int O = wt.shape(0); // Out channels
+  const int C = wt.shape(3); // In channels
+  const int wH = wt.shape(1); // Weight spatial dim
+  const int wW = wt.shape(2); // Weight spatial dim
+
+  auto conv_dtype = out.dtype();
+  auto& encoder = cpu::get_command_encoder(stream);
+
+  // Pad input
+  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+  array in_padded(padded_shape, conv_dtype, nullptr, {});
+
+  // Fill with zeros
+  std::vector<array> temps;
+  temps.push_back(array(0, conv_dtype));
+  copy(temps.back(), in_padded, CopyType::Scalar, stream);
+
+  // Pick input slice from padded
+  size_t data_offset =
+      padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
+  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
+  in_padded_slice.copy_shared_buffer(
+      in_padded,
+      in_padded.strides(),
+      in_padded.flags(),
+      in_padded_slice.size(),
+      data_offset);
+  temps.push_back(in_padded_slice);
+
+  // Copy input values into the slice
+  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
+
+  // Make strided view
+  Shape strided_shape = {N, oH, oW, wH, wW, C};
+
+  Strides strided_strides = {
+      in_padded.strides()[0],
+      in_padded.strides()[1] * wt_strides[0],
+      in_padded.strides()[2] * wt_strides[1],
+      in_padded.strides()[1],
+      in_padded.strides()[2],
+      in_padded.strides()[3]};
+  auto flags = in_padded.flags();
+
+  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
+  in_strided_view.copy_shared_buffer(
+      in_padded, strided_strides, flags, in_strided_view.size(), 0);
+
+  // Materialize strided view
+  Shape strided_reshape = {N * oH * oW, wH * wW * C};
+  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
+  copy(in_strided_view, in_strided, CopyType::General, stream);
+  temps.push_back(in_strided);
+
+  // Check wt dtype and prepare
+  auto gemm_wt = wt;
+  auto gemm_out = out;
+
+  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
+    auto ctype =
+        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
+    gemm_wt = array(wt.shape(), float32, nullptr, {});
+    copy(wt, gemm_wt, ctype, stream);
+    temps.push_back(gemm_wt);
+  }
+
+  if (out.dtype() != float32) {
+    gemm_out = array(out.shape(), float32, nullptr, {});
+    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
+    temps.push_back(gemm_out);
+  }
+
+  encoder.set_input_array(in_strided);
+  encoder.set_input_array(gemm_wt);
+  encoder.set_output_array(gemm_out);
+
+  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
+                    gemm_wt_ptr = gemm_wt.data<float>(),
+                    gemm_out_ptr = gemm_out.data<float>(),
+                    strided_reshape = std::move(strided_reshape),
+                    O]() {
+    // Perform gemm
+    cblas_sgemm(
+        CblasRowMajor,
+        CblasNoTrans, // no trans A
+        CblasTrans, // transB
+        strided_reshape[0], // M
+        O, // N
+        strided_reshape[1], // K
+        1.0f, // alpha
+        in_strided_ptr,
+        strided_reshape[1], // lda
+        gemm_wt_ptr,
+        strided_reshape[1], // ldb
+        0.0f, // beta
+        gemm_out_ptr,
+        O // ldc
+    );
+  });
+
+  // Copy results if needed
+  if (out.dtype() != float32) {
+    copy_inplace(gemm_out, out, CopyType::Vector, stream);
  }
  encoder.add_temporaries(std::move(temps));
}
@@ -1000,8 +1091,7 @@ void explicit_gemm_conv_ND_cpu(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const bool flip,
@@ -1024,21 +1114,20 @@ void explicit_gemm_conv_ND_cpu(
  Shape padded_shape(in.shape().size());
  padded_shape.front() = N;
  for (size_t i = 0; i < iDim.size(); i++) {
-    padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
+    padded_shape[i + 1] = iDim[i] + 2 * padding[i];
  }
  padded_shape.back() = C;
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  std::vector<array> temps = {array(0, conv_dtype)};
-  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
+  copy(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded
  size_t data_offset = 0;
-  for (size_t i = 0; i < padding_lo.size(); i++) {
-    data_offset += padding_lo[i] * in_padded.strides()[i + 1];
+  for (size_t i = 0; i < padding.size(); i++) {
+    data_offset += padding[i] * in_padded.strides()[i + 1];
  }

  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
@@ -1048,7 +1137,7 @@ void explicit_gemm_conv_ND_cpu(
      data_offset);

  // Copy input values into the slice
-  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
+  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
  temps.push_back(in_padded_slice);

  // Make strided view
@@ -1087,7 +1176,7 @@ void explicit_gemm_conv_ND_cpu(
  }

  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
-  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
+  copy(in_strided_view, in_strided, CopyType::General, stream);
  temps.push_back(in_strided);

  // Check wt dtype and prepare
@@ -1098,13 +1187,13 @@ void explicit_gemm_conv_ND_cpu(
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
-    copy_cpu(wt, gemm_wt, ctype, stream);
+    copy(wt, gemm_wt, ctype, stream);
    temps.push_back(gemm_wt);
  }

  if (flip) {
    auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
-    copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
+    copy(gemm_wt, gemm_wt_, CopyType::Vector, stream);
    temps.push_back(gemm_wt_);

    // Calculate the total size of the spatial dimensions
@@ -1159,7 +1248,7 @@ void explicit_gemm_conv_ND_cpu(

  // Copy results if needed
  if (out.dtype() != float32) {
-    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
+    copy_inplace(gemm_out, out, CopyType::Vector, stream);
  }
  encoder.add_temporaries(std::move(temps));
}
@@ -1172,8 +1261,7 @@ void conv_1D_cpu(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -1182,40 +1270,22 @@ void conv_1D_cpu(
  const int groups = in.shape().back() / wt.shape().back();
  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
    return explicit_gemm_conv_1D_cpu(
-        in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
+        in, wt, out, padding, wt_strides, wt_dilation, stream);
  }
  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
    return explicit_gemm_conv_ND_cpu(
-        in,
-        wt,
-        out,
-        padding_lo,
-        padding_hi,
-        wt_strides,
-        wt_dilation,
-        flip,
-        stream);
+        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
  }

  return dispatch_slow_conv_1D(
-      in,
-      wt,
-      out,
-      padding_lo,
-      padding_hi,
-      wt_strides,
-      wt_dilation,
-      in_dilation,
-      flip,
-      stream);
+      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}

void conv_2D_cpu(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -1225,35 +1295,18 @@ void conv_2D_cpu(
  if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
      in_dilation[1] == 1 && groups == 1) {
    return explicit_gemm_conv_ND_cpu(
-        in,
-        wt,
-        out,
-        padding_lo,
-        padding_hi,
-        wt_strides,
-        wt_dilation,
-        flip,
-        stream);
+        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
  }

  return dispatch_slow_conv_2D(
-      in,
-      wt,
-      out,
-      padding_lo,
-      padding_hi,
-      wt_strides,
-      wt_dilation,
-      in_dilation,
-      flip,
-      stream);
+      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}

void conv_3D_cpu(
    const array& in,
    const array& wt,
    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
+    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
@@ -1264,28 +1317,11 @@ void conv_3D_cpu(
      in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
      groups == 1) {
    return explicit_gemm_conv_ND_cpu(
-        in,
-        wt,
-        out,
-        padding_lo,
-        padding_hi,
-        wt_strides,
-        wt_dilation,
-        flip,
-        stream);
+        in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
  }

  return dispatch_slow_conv_3D(
-      in,
-      wt,
-      out,
-      padding_lo,
-      padding_hi,
-      wt_strides,
-      wt_dilation,
-      in_dilation,
-      flip,
-      stream);
+      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
}

} // namespace
@@ -1302,8 +1338,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
      in,
      wt,
      out,
-      padding_lo_,
-      padding_hi_,
+      padding_,
      kernel_strides_,
      kernel_dilation_,
      input_dilation_,
@@ -1316,8 +1351,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
      in,
      wt,
      out,
-      padding_lo_,
-      padding_hi_,
+      padding_,
      kernel_strides_,
      kernel_dilation_,
      input_dilation_,
@@ -1330,8 +1364,7 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
      in,
      wt,
      out,
-      padding_lo_,
-      padding_hi_,
+      padding_,
      kernel_strides_,
      kernel_dilation_,
      input_dilation_,
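The explicit_gemm_conv_* paths above all follow the same strategy: materialize an im2col-style strided view of the padded input, then express the convolution as a single GEMM. Below is a minimal stand-alone sketch of that strategy for a 1-D convolution with stride 1 and no padding, under assumed layouts; the only real library call is cblas_sgemm (any CBLAS implementation), everything else is local to the example.

// Illustrative sketch only: im2col + one GEMM for a 1-D convolution.
#include <cblas.h>
#include <vector>

// in: [N, iH, C] row-major, wt: [O, wH, C] row-major, out: [N, oH, O]
void conv1d_im2col_sgemm(
    const std::vector<float>& in,
    const std::vector<float>& wt,
    std::vector<float>& out,
    int N, int iH, int C, int O, int wH) {
  int oH = iH - wH + 1;
  // Materialize the strided view: each output position gathers a [wH, C] patch.
  std::vector<float> patches((size_t)N * oH * wH * C);
  for (int n = 0; n < N; ++n) {
    for (int oh = 0; oh < oH; ++oh) {
      for (int wh = 0; wh < wH; ++wh) {
        for (int c = 0; c < C; ++c) {
          patches[(((size_t)n * oH + oh) * wH + wh) * C + c] =
              in[((size_t)n * iH + oh + wh) * C + c];
        }
      }
    }
  }
  out.assign((size_t)N * oH * O, 0.0f);
  // One GEMM: [N*oH, wH*C] x [O, wH*C]^T -> [N*oH, O]
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans,
      CblasTrans,
      N * oH, // M
      O, // N
      wH * C, // K
      1.0f,
      patches.data(),
      wH * C, // lda
      wt.data(),
      wH * C, // ldb
      0.0f,
      out.data(),
      O // ldc
  );
}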
@@ -295,11 +295,7 @@ inline void copy_inplace_dispatch(

 } // namespace

-void copy_cpu_inplace(
-    const array& src,
-    array& dst,
-    CopyType ctype,
-    Stream stream) {
+void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
   auto& encoder = cpu::get_command_encoder(stream);
   encoder.set_input_array(src);
   encoder.set_output_array(dst);
@@ -309,7 +305,7 @@ void copy_cpu_inplace(
       ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
 }

-void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
+void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
   bool donated = set_copy_output_data(src, dst, ctype);
   if (donated && src.dtype() == dst.dtype()) {
     // If the output has the same type as the input then there is nothing to
@@ -319,10 +315,10 @@ void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
   if (ctype == CopyType::GeneralGeneral) {
     ctype = CopyType::General;
   }
-  copy_cpu_inplace(src, dst, ctype, stream);
+  copy_inplace(src, dst, ctype, stream);
 }

-void copy_cpu_inplace(
+void copy_inplace(
     const array& src,
     array& dst,
     const Shape& data_shape,
@@ -377,10 +373,4 @@ void copy_cpu_inplace(
   });
 }

-array contiguous_copy_cpu(const array& arr, Stream stream) {
-  array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-  copy_cpu(arr, arr_copy, CopyType::General, stream);
-  return arr_copy;
-}
-
 } // namespace mlx::core
@@ -10,14 +10,10 @@

 namespace mlx::core {

-void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
-void copy_cpu_inplace(
-    const array& src,
-    array& dst,
-    CopyType ctype,
-    Stream stream);
+void copy(const array& src, array& dst, CopyType ctype, Stream stream);
+void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);

-void copy_cpu_inplace(
+void copy_inplace(
     const array& src,
     array& dst,
     const Shape& data_shape,
@@ -30,7 +26,4 @@ void copy_cpu_inplace(
     const std::optional<array>& dynamic_i_offset = std::nullopt,
     const std::optional<array>& dynamic_o_offset = std::nullopt);

-// Return a contiguous array with same shape that copies the data of |arr|.
-array contiguous_copy_cpu(const array& arr, Stream stream);
-
 } // namespace mlx::core
@@ -13,7 +13,9 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
   if (arr.flags().row_contiguous) {
     return {arr, false};
   } else {
-    return {contiguous_copy_cpu(arr, stream), true};
+    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+    copy(arr, arr_copy, CopyType::General, stream);
+    return {arr_copy, true};
   }
 };
 
@@ -32,7 +34,8 @@ void AllReduce::eval_cpu(
     }
     return in;
   } else {
-    array arr_copy = contiguous_copy_cpu(in, s);
+    array arr_copy(in.shape(), in.dtype(), nullptr, {});
+    copy(in, arr_copy, CopyType::General, s);
     out.copy_shared_buffer(arr_copy);
     return arr_copy;
   }
@@ -95,9 +98,4 @@ void Recv::eval_cpu(
   distributed::detail::recv(group(), outputs[0], src_, stream());
 }
 
-void ReduceScatter::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  throw std::runtime_error("[ReduceScatter] Not implemented yet.");
-}
 } // namespace mlx::core::distributed

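Note: the pattern added (+) at these call sites — construct an empty destination array, run a General copy, and hand the dense result to the kernel — is exactly what the removed (-) side factors into the contiguous_copy_cpu helper shown in the copy.cpp/copy.h hunks above. A minimal sketch of that idiom, assuming the copy/CopyType declarations from copy.h on the (+) side; the function name here is hypothetical:

// Sketch only: produce a row-contiguous copy of `arr` before handing it to a kernel.
#include "mlx/backend/cpu/copy.h"

namespace mlx::core {

array make_contiguous(const array& arr, Stream stream) {
  array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); // unallocated destination
  copy(arr, arr_copy, CopyType::General, stream);        // gather into a dense row-major layout
  return arr_copy;
}

} // namespace mlx::core
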
@@ -1,281 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/allocator.h"
-#include "mlx/array.h"
-#include "mlx/backend/cpu/copy.h"
-#include "mlx/backend/cpu/encoder.h"
-#include "mlx/backend/cpu/lapack.h"
-#include "mlx/linalg.h"
-#include "mlx/primitives.h"
-
-namespace mlx::core {
-
-namespace {
-
-template <typename T>
-complex64_t to_complex(T r, T i) {
-  return {static_cast<float>(r), static_cast<float>(i)};
-}
-
-template <typename T, class Enable = void>
-struct EigWork {};
-
-template <typename T>
-struct EigWork<
-    T,
-    typename std::enable_if<std::is_floating_point<T>::value>::type> {
-  using O = complex64_t;
-
-  char jobl;
-  char jobr;
-  int N;
-  int lwork;
-  int info;
-  std::vector<array::Data> buffers;
-
-  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
-      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
-    T work;
-    int n_vecs_l = compute_eigenvectors ? N_ : 1;
-    int n_vecs_r = 1;
-    geev<T>(
-        &jobl,
-        &jobr,
-        &N,
-        nullptr,
-        &N,
-        nullptr,
-        nullptr,
-        nullptr,
-        &n_vecs_l,
-        nullptr,
-        &n_vecs_r,
-        &work,
-        &lwork,
-        &info);
-    lwork = static_cast<int>(work);
-
-    buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
-    if (compute_eigenvectors) {
-      buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
-    }
-    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
-  }
-
-  void run(T* a, O* values, O* vectors) {
-    auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
-    T* vec_tmp = nullptr;
-    if (vectors) {
-      vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
-    }
-    auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
-
-    int n_vecs_l = vectors ? N : 1;
-    int n_vecs_r = 1;
-    geev<T>(
-        &jobl,
-        &jobr,
-        &N,
-        a,
-        &N,
-        eig_tmp,
-        eig_tmp + N,
-        vectors ? vec_tmp : nullptr,
-        &n_vecs_l,
-        nullptr,
-        &n_vecs_r,
-        work,
-        &lwork,
-        &info);
-
-    for (int i = 0; i < N; ++i) {
-      values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
-    }
-
-    if (vectors) {
-      for (int i = 0; i < N; ++i) {
-        if (values[i].imag() != 0) {
-          for (int j = 0; j < N; ++j) {
-            vectors[i * N + j] =
-                to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
-            vectors[(i + 1) * N + j] =
-                to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
-          }
-          i += 1;
-        } else {
-          for (int j = 0; j < N; ++j) {
-            vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
-          }
-        }
-      }
-    }
-  }
-};
-
-template <>
-struct EigWork<std::complex<float>> {
-  using T = std::complex<float>;
-  using R = float;
-  using O = T;
-
-  char jobl;
-  char jobr;
-  int N;
-  int lwork;
-  int lrwork;
-  int info;
-  std::vector<array::Data> buffers;
-
-  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
-      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
-    T work;
-    R rwork;
-    int n_vecs_l = compute_eigenvectors ? N_ : 1;
-    int n_vecs_r = 1;
-    geev<T>(
-        &jobl,
-        &jobr,
-        &N,
-        nullptr,
-        &N,
-        nullptr,
-        nullptr,
-        &n_vecs_l,
-        nullptr,
-        &n_vecs_r,
-        &work,
-        &lwork,
-        &rwork,
-        &info);
-    lwork = static_cast<int>(work.real());
-    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
-    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
-  }
-
-  void run(T* a, T* values, T* vectors) {
-    int n_vecs_l = vectors ? N : 1;
-    int n_vecs_r = 1;
-    geev<T>(
-        &jobl,
-        &jobr,
-        &N,
-        a,
-        &N,
-        values,
-        vectors,
-        &n_vecs_l,
-        nullptr,
-        &n_vecs_r,
-        static_cast<T*>(buffers[0].buffer.raw_ptr()),
-        &lwork,
-        static_cast<R*>(buffers[1].buffer.raw_ptr()),
-        &info);
-  }
-};
-
-template <typename T>
-void eig_impl(
-    array& a,
-    array& vectors,
-    array& values,
-    bool compute_eigenvectors,
-    Stream stream) {
-  auto a_ptr = a.data<T>();
-  auto val_ptr = values.data<complex64_t>();
-
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(a);
-  encoder.set_output_array(values);
-  complex64_t* vec_ptr = nullptr;
-  if (compute_eigenvectors) {
-    encoder.set_output_array(vectors);
-    vec_ptr = vectors.data<complex64_t>();
-  }
-  encoder.dispatch([a_ptr,
-                    val_ptr,
-                    vec_ptr,
-                    compute_eigenvectors,
-                    N = vectors.shape(-1),
-                    size = vectors.size()]() mutable {
-    char jobr = 'N';
-    char jobl = compute_eigenvectors ? 'V' : 'N';
-
-    EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
-
-    for (size_t i = 0; i < size / (N * N); ++i) {
-      work.run(a_ptr, val_ptr, vec_ptr);
-      a_ptr += N * N;
-      val_ptr += N;
-      if (vec_ptr) {
-        vec_ptr += N * N;
-      }
-      if (work.info != 0) {
-        std::stringstream msg;
-        msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
-            << work.info;
-        throw std::runtime_error(msg.str());
-      }
-    }
-  });
-  encoder.add_temporary(a);
-}
-
-} // namespace
-
-void Eig::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  const auto& a = inputs[0];
-  auto& values = outputs[0];
-
-  auto vectors = compute_eigenvectors_
-      ? outputs[1]
-      : array(a.shape(), complex64, nullptr, {});
-
-  auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
-  copy_cpu(
-      a,
-      a_copy,
-      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
-      stream());
-
-  values.set_data(allocator::malloc(values.nbytes()));
-
-  if (compute_eigenvectors_) {
-    // Set the strides and flags so the eigenvectors
-    // are in the columns of the output
-    auto flags = vectors.flags();
-    auto strides = vectors.strides();
-    auto ndim = a.ndim();
-    std::swap(strides[ndim - 1], strides[ndim - 2]);
-
-    if (a.size() > 1) {
-      flags.row_contiguous = false;
-      if (ndim > 2) {
-        flags.col_contiguous = false;
-      } else {
-        flags.col_contiguous = true;
-      }
-    }
-    vectors.set_data(
-        allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
-  }
-  switch (a.dtype()) {
-    case float32:
-      eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
-      break;
-    case float64:
-      eig_impl<double>(
-          a_copy, vectors, values, compute_eigenvectors_, stream());
-      break;
-    case complex64:
-      eig_impl<std::complex<float>>(
-          a_copy, vectors, values, compute_eigenvectors_, stream());
-      break;
-    default:
-      throw std::runtime_error(
-          "[Eig::eval_cpu] only supports float32, float64, or complex64.");
-  }
-}
-
-} // namespace mlx::core

@@ -12,25 +12,31 @@ namespace mlx::core {
 
 namespace {
 
-template <typename T, class Enable = void>
-struct EighWork {};
-
 template <typename T>
-struct EighWork<
-    T,
-    typename std::enable_if<std::is_floating_point<T>::value>::type> {
-  using R = T;
+void eigh_impl(
+    array& vectors,
+    array& values,
+    const std::string& uplo,
+    bool compute_eigenvectors,
+    Stream stream) {
+  auto vec_ptr = vectors.data<T>();
+  auto eig_ptr = values.data<T>();
+  char jobz = compute_eigenvectors ? 'V' : 'N';
 
-  char jobz;
-  char uplo;
-  int N;
-  int lwork;
-  int liwork;
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_output_array(vectors);
+  encoder.set_output_array(values);
+  encoder.dispatch([vec_ptr,
+                    eig_ptr,
+                    jobz,
+                    uplo = uplo[0],
+                    N = vectors.shape(-1),
+                    size = vectors.size()]() mutable {
+    // Work query
+    int lwork = -1;
+    int liwork = -1;
   int info;
-  std::vector<array::Data> buffers;
-
-  EighWork(char jobz_, char uplo_, int N_)
-      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
+    {
       T work;
       int iwork;
       syevd<T>(
@@ -47,132 +53,29 @@ struct EighWork<
          &info);
       lwork = static_cast<int>(work);
       liwork = iwork;
-    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
-    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
   }
 
-  void run(T* vectors, T* values) {
+    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
+    auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
+    for (size_t i = 0; i < size / (N * N); ++i) {
       syevd<T>(
           &jobz,
          &uplo,
          &N,
-        vectors,
+          vec_ptr,
          &N,
-        values,
-        static_cast<T*>(buffers[0].buffer.raw_ptr()),
-        &lwork,
-        static_cast<int*>(buffers[1].buffer.raw_ptr()),
-        &liwork,
-        &info);
-  }
-};
-
-template <>
-struct EighWork<std::complex<float>> {
-  using T = std::complex<float>;
-  using R = float;
-
-  char jobz;
-  char uplo;
-  int N;
-  int lwork;
-  int lrwork;
-  int liwork;
-  int info;
-  std::vector<array::Data> buffers;
-
-  EighWork(char jobz_, char uplo_, int N_)
-      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
-    T work;
-    R rwork;
-    int iwork;
-    heevd<T>(
-        &jobz,
-        &uplo,
-        &N,
-        nullptr,
-        &N,
-        nullptr,
-        &work,
-        &lwork,
-        &rwork,
-        &lrwork,
-        &iwork,
-        &liwork,
-        &info);
-    lwork = static_cast<int>(work.real());
-    lrwork = static_cast<int>(rwork);
-    liwork = iwork;
-    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
-    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
-    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
-  }
-
-  void run(T* vectors, R* values) {
-    heevd<T>(
-        &jobz,
-        &uplo,
-        &N,
-        vectors,
-        &N,
-        values,
-        static_cast<T*>(buffers[0].buffer.raw_ptr()),
-        &lwork,
-        static_cast<R*>(buffers[1].buffer.raw_ptr()),
-        &lrwork,
-        static_cast<int*>(buffers[2].buffer.raw_ptr()),
-        &liwork,
-        &info);
-    if (jobz == 'V') {
-      // We have pre-transposed the vectors but we also must conjugate them
-      // when they are complex.
-      //
-      // We could vectorize this but it is so fast in comparison to heevd that
-      // it doesn't really matter.
-      for (int i = 0; i < N; i++) {
-        for (int j = 0; j < N; j++) {
-          *vectors = std::conj(*vectors);
-          vectors++;
-        }
-      }
-    }
-  }
-};
-
-template <typename T>
-void eigh_impl(
-    array& vectors,
-    array& values,
-    const std::string& uplo,
-    bool compute_eigenvectors,
-    Stream stream) {
-  using R = typename EighWork<T>::R;
-
-  auto vec_ptr = vectors.data<T>();
-  auto eig_ptr = values.data<R>();
-  char jobz = compute_eigenvectors ? 'V' : 'N';
-
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_output_array(vectors);
-  encoder.set_output_array(values);
-  encoder.dispatch([vec_ptr,
          eig_ptr,
-                    jobz,
-                    uplo = uplo[0],
-                    N = vectors.shape(-1),
-                    size = vectors.size()]() mutable {
-    // Work query
-    EighWork<T> work(jobz, uplo, N);
-
-    // Work loop
-    for (size_t i = 0; i < size / (N * N); ++i) {
-      work.run(vec_ptr, eig_ptr);
+          static_cast<T*>(work_buf.buffer.raw_ptr()),
+          &lwork,
+          static_cast<int*>(iwork_buf.buffer.raw_ptr()),
+          &liwork,
+          &info);
       vec_ptr += N * N;
       eig_ptr += N;
-      if (work.info != 0) {
+      if (info != 0) {
         std::stringstream msg;
         msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-            << work.info;
+            << info;
         throw std::runtime_error(msg.str());
       }
     }
@@ -196,7 +99,7 @@ void Eigh::eval_cpu(
 
   values.set_data(allocator::malloc(values.nbytes()));
 
-  copy_cpu(
+  copy(
       a,
       vectors,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -228,10 +131,6 @@ void Eigh::eval_cpu(
       eigh_impl<double>(
           vectors, values, uplo_, compute_eigenvectors_, stream());
       break;
-    case complex64:
-      eigh_impl<std::complex<float>>(
-          vectors, values, uplo_, compute_eigenvectors_, stream());
-      break;
     default:
       throw std::runtime_error(
          "[Eigh::eval_cpu] only supports float32 or float64.");

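Note: both the removed EighWork constructors and the added inline query block above rely on LAPACK's two-pass workspace protocol: call the eigensolver once with lwork = -1 so it reports the optimal buffer sizes, allocate those buffers, then call it again to do the actual decomposition. A self-contained sketch of that idiom against the raw LAPACK symbol (ssyevd_ here; the syevd<T>/heevd<T> wrappers used in the diff forward to the corresponding routines, so treat that mapping as an assumption):

#include <stdexcept>
#include <string>
#include <vector>

// Fortran LAPACK symbol: symmetric eigensolver (divide and conquer), single precision.
extern "C" void ssyevd_(
    const char* jobz, const char* uplo, const int* n, float* a, const int* lda,
    float* w, float* work, const int* lwork, int* iwork, const int* liwork,
    int* info);

// Sketch: eigenvalues (and optionally eigenvectors) of an n x n symmetric
// matrix stored densely in `a`; `w` receives the n eigenvalues.
void syevd_two_pass(float* a, float* w, int n, bool vectors) {
  char jobz = vectors ? 'V' : 'N';
  char uplo = 'L';
  int info = 0;

  // Pass 1: workspace query. LAPACK writes the required sizes into the
  // first element of each workspace argument.
  float work_size = 0.0f;
  int iwork_size = 0;
  int lwork = -1;
  int liwork = -1;
  ssyevd_(&jobz, &uplo, &n, a, &n, w, &work_size, &lwork, &iwork_size, &liwork, &info);

  // Pass 2: allocate what was requested and run the decomposition.
  lwork = static_cast<int>(work_size);
  liwork = iwork_size;
  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);
  ssyevd_(&jobz, &uplo, &n, a, &n, w, work.data(), &lwork, iwork.data(), &liwork, &info);
  if (info != 0) {
    throw std::runtime_error("ssyevd_ failed with error code " + std::to_string(info));
  }
}
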
@@ -1,4 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include <Accelerate/Accelerate.h>
+
 #include "mlx/array.h"
@@ -48,15 +49,9 @@ void matmul_bnns(
   size_t K = a_shape[ndim - 1];
 
   BNNSDataType bnns_dtype = to_bnns_dtype<T>();
+
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
-  if (beta != 1.0 && beta != 0.0) {
-    // scale the output
-    for (auto i = 0; i < batch_size * M * N; ++i) {
-      out[i] *= beta;
-    }
-    beta = 1.0;
-  }
   const BNNSLayerParametersBroadcastMatMul gemm_params{
       /* float alpha = */ alpha,
       /* float beta = */ beta,

@@ -88,47 +88,4 @@ void matmul<double>(
   }
 }
 
-template <>
-void matmul<complex64_t>(
-    const complex64_t* a,
-    const complex64_t* b,
-    complex64_t* out,
-    bool a_transposed,
-    bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
-    float alpha,
-    float beta,
-    size_t batch_size,
-    const Shape& a_shape,
-    const Strides& a_strides,
-    const Shape& b_shape,
-    const Strides& b_strides) {
-  auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
-  auto calpha = static_cast<complex64_t>(alpha);
-  auto cbeta = static_cast<complex64_t>(beta);
-
-  for (int i = 0; i < batch_size; ++i) {
-    cblas_cgemm(
-        CblasRowMajor,
-        a_transposed ? CblasTrans : CblasNoTrans, // transA
-        b_transposed ? CblasTrans : CblasNoTrans, // transB
-        M,
-        N,
-        K,
-        &calpha,
-        a + elem_to_loc(M * K * i, a_shape, a_strides),
-        lda,
-        b + elem_to_loc(K * N * i, b_shape, b_strides),
-        ldb,
-        &cbeta,
-        out + M * N * i,
-        ldc);
-  }
-}
-
 } // namespace mlx::core

@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (in.flags().row_contiguous && in.is_donatable()) {
     out.copy_shared_buffer(in);
   } else {
-    copy_cpu(
+    copy(
         in,
         out,
         in.flags().row_contiguous ? CopyType::Vector : CopyType::General,

@@ -257,11 +257,15 @@ void gather_axis(
     const array& ind,
     array& out,
     const int axis) {
-  auto shape = remove_index(ind.shape(), axis);
-  ContiguousIterator ind_it(
-      shape, remove_index(ind.strides(), axis), src.ndim() - 1);
-  ContiguousIterator src_it(
-      shape, remove_index(src.strides(), axis), src.ndim() - 1);
+  auto strides = ind.strides();
+  strides.erase(strides.begin() + axis);
+  auto shape = ind.shape();
+  shape.erase(shape.begin() + axis);
+  ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
 
+  strides = src.strides();
+  strides.erase(strides.begin() + axis);
+  ContiguousIterator src_it(shape, strides, src.ndim() - 1);
+
   auto ind_ptr = ind.data<IdxT>();
   auto src_ptr = src.data<T>();
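Note: the removed (-) lines fold the shape/stride manipulation into a remove_index helper. Its definition is not part of this compare; the behaviour the call sites need is simply "copy the container and drop the entry at axis". A hedged sketch of such a helper (the name and signature here are assumptions inferred from the calls above):

// Assumed behaviour of remove_index as used by gather_axis / scatter_axis:
// return a copy of `vec` with the element at position `axis` erased.
template <typename Vec>
Vec remove_index(Vec vec, int axis) {
  vec.erase(vec.begin() + axis);
  return vec;
}
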
@@ -517,7 +521,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Copy src into out (copy allocates memory for out)
   auto ctype =
       src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
-  copy_cpu(src, out, ctype, stream());
+  copy(src, out, ctype, stream());
 
   auto& encoder = cpu::get_command_encoder(stream());
   std::vector<array> inds;
@@ -581,11 +585,15 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
 
 template <typename T, typename IdxT, typename OpT>
 void scatter_axis(array& out, const array idx, const array& upd, int axis) {
-  auto shape = remove_index(idx.shape(), axis);
-  ContiguousIterator idx_it(
-      shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
-  ContiguousIterator upd_it(
-      shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
+  auto strides = idx.strides();
+  strides.erase(strides.begin() + axis);
+  auto shape = idx.shape();
+  shape.erase(shape.begin() + axis);
+  ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
 
+  strides = upd.strides();
+  strides.erase(strides.begin() + axis);
+  ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
+
   auto idx_ptr = idx.data<IdxT>();
   auto upd_ptr = upd.data<T>();
@@ -686,7 +694,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Copy src into out (copy allocates memory for out)
   auto ctype =
       src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
-  copy_cpu(src, out, ctype, stream());
+  copy(src, out, ctype, stream());
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(idx);
@@ -747,108 +755,4 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
-template <typename T>
-void masked_scatter_impl(const array& mask, const array& src, array& out) {
-  ContiguousIterator mask_it(mask);
-  ContiguousIterator src_it(src);
-  ContiguousIterator out_it(out);
-
-  const bool* mask_ptr = mask.data<bool>();
-  const T* src_ptr = src.data<T>();
-  T* dst_ptr = out.data<T>();
-
-  const size_t batch_count = mask.shape(0);
-  const size_t mask_batch_size = mask.size() / batch_count;
-  const size_t src_batch_size = src.size() / batch_count;
-
-  for (uint b = 0; b < batch_count; ++b) {
-    size_t src_consumed = 0;
-    src_it.seek(b * src_batch_size);
-
-    for (size_t i = 0; i < mask_batch_size; ++i) {
-      if (mask_ptr[mask_it.loc]) {
-        if (src_consumed >= src_batch_size) {
-          throw std::runtime_error(
-              "[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
-        }
-        dst_ptr[out_it.loc] = src_ptr[src_it.loc];
-        src_it.step();
-        ++src_consumed;
-      }
-      mask_it.step();
-      out_it.step();
-    }
-  }
-}
-
-void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 3);
-
-  auto& dst = inputs[0];
-  auto& mask = inputs[1];
-  auto& src = inputs[2];
-
-  // Copy src into out (copy allocates memory for out)
-  auto ctype =
-      dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
-  copy_cpu(dst, out, ctype, stream());
-
-  if (mask.size() == 0) {
-    return;
-  }
-
-  auto& encoder = cpu::get_command_encoder(stream());
-  encoder.set_input_array(mask);
-  encoder.set_input_array(src);
-  encoder.set_output_array(out);
-  encoder.dispatch([mask = array::unsafe_weak_copy(mask),
-                    src = array::unsafe_weak_copy(src),
-                    out = array::unsafe_weak_copy(out)]() mutable {
-    switch (out.dtype()) {
-      case bool_:
-        masked_scatter_impl<bool>(mask, src, out);
-        break;
-      case uint8:
-        masked_scatter_impl<uint8_t>(mask, src, out);
-        break;
-      case uint16:
-        masked_scatter_impl<uint16_t>(mask, src, out);
-        break;
-      case uint32:
-        masked_scatter_impl<uint32_t>(mask, src, out);
-        break;
-      case uint64:
-        masked_scatter_impl<uint64_t>(mask, src, out);
-        break;
-      case int8:
-        masked_scatter_impl<int8_t>(mask, src, out);
-        break;
-      case int16:
-        masked_scatter_impl<int16_t>(mask, src, out);
-        break;
-      case int32:
-        masked_scatter_impl<int32_t>(mask, src, out);
-        break;
-      case int64:
-        masked_scatter_impl<int64_t>(mask, src, out);
-        break;
-      case float16:
-        masked_scatter_impl<float16_t>(mask, src, out);
-        break;
-      case float32:
-        masked_scatter_impl<float>(mask, src, out);
-        break;
-      case float64:
-        masked_scatter_impl<double>(mask, src, out);
-        break;
-      case bfloat16:
-        masked_scatter_impl<bfloat16_t>(mask, src, out);
-        break;
-      case complex64:
-        masked_scatter_impl<complex64_t>(mask, src, out);
-        break;
-    }
-  });
-}
-
 } // namespace mlx::core

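Note: the removed masked_scatter_impl implements boolean-mask assignment: walk the mask and the destination together and consume the next source element wherever the mask is true, failing if the source runs out. A stripped-down, contiguous-only sketch of those semantics (plain pointers instead of ContiguousIterator, so the per-batch seek and strided iteration of the real kernel are omitted; the function name is hypothetical):

#include <cstddef>
#include <stdexcept>

// Sketch: out[i] = next unconsumed src element wherever mask[i] is true.
template <typename T>
void masked_scatter_contiguous(
    const bool* mask, const T* src, T* out, std::size_t out_size, std::size_t src_size) {
  std::size_t consumed = 0;
  for (std::size_t i = 0; i < out_size; ++i) {
    if (mask[i]) {
      if (consumed >= src_size) {
        throw std::runtime_error("Source does not have enough elements for mask.");
      }
      out[i] = src[consumed++];
    }
  }
}
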
@@ -115,7 +115,7 @@ void inverse_impl(
   // (A⁻¹)ᵀ = (Aᵀ)⁻¹
 
   // The inverse is computed in place, so just copy the input to the output.
-  copy_cpu(
+  copy(
       a,
       inv,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

@@ -2,7 +2,6 @@
 
 #include "mlx/backend/cpu/jit_compiler.h"
 
-#include <algorithm>
 #include <sstream>
 #include <vector>
 