mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
Compare commits
262 Commits
v0.12.2
...
quantized-
Author | SHA1 | Date | |
---|---|---|---|
![]() |
f5b0f11968 | ||
![]() |
b509c2ad76 | ||
![]() |
852336b8a2 | ||
![]() |
6649244686 | ||
![]() |
047a584e3d | ||
![]() |
ef14b1e9c3 | ||
![]() |
5824626c0b | ||
![]() |
8e88e30d95 | ||
![]() |
0eb56d5be0 | ||
![]() |
f70764a162 | ||
![]() |
dad1b00b13 | ||
![]() |
430ffef58a | ||
![]() |
3d17077187 | ||
![]() |
c9b41d460f | ||
![]() |
32972a5924 | ||
![]() |
f6afb9c09b | ||
![]() |
3ddc07e936 | ||
![]() |
c26208f67d | ||
![]() |
d15fa13daf | ||
![]() |
58a855682c | ||
![]() |
92d7cb71f8 | ||
![]() |
50d8bed468 | ||
![]() |
9dd72cd421 | ||
![]() |
343aa46b78 | ||
![]() |
b8ab89b413 | ||
![]() |
f9f8c167d4 | ||
![]() |
3f86399922 | ||
![]() |
2b8ace6a03 | ||
![]() |
0ab8e099e8 | ||
![]() |
020f048cd0 | ||
![]() |
881615b072 | ||
![]() |
0eef4febfd | ||
![]() |
b54a70ec2d | ||
![]() |
bf6ec92216 | ||
![]() |
c21331d47f | ||
![]() |
e1c9600da3 | ||
![]() |
1fa0d20a30 | ||
![]() |
3274c6a087 | ||
![]() |
9b12093739 | ||
![]() |
f374b6ca4d | ||
![]() |
0070e1db40 | ||
![]() |
95d04805b3 | ||
![]() |
e4534dac17 | ||
![]() |
fef3c4ec1d | ||
![]() |
1bdc038bf9 | ||
![]() |
5523d9c426 | ||
![]() |
d878015228 | ||
![]() |
5900e3249f | ||
![]() |
bacced53d3 | ||
![]() |
4a64d4bff1 | ||
![]() |
b1e2b53c2d | ||
![]() |
11354d5bff | ||
![]() |
718aea3f1d | ||
![]() |
5b6f38df2b | ||
![]() |
0b4a58699e | ||
![]() |
4f9f9ebb6f | ||
![]() |
afc9c0ec1b | ||
![]() |
195b429d99 | ||
![]() |
2b878e9dd7 | ||
![]() |
67b6bf530d | ||
![]() |
6af5ca35b2 | ||
![]() |
4f46e9c997 | ||
![]() |
c6739ba7f3 | ||
![]() |
914409fef9 | ||
![]() |
8d68a3e805 | ||
![]() |
6bbcc453ef | ||
![]() |
d5ed4d7a71 | ||
![]() |
669c27140d | ||
![]() |
adcc88e208 | ||
![]() |
d6492b0163 | ||
![]() |
b3f52c9fbe | ||
![]() |
bd8396fad8 | ||
![]() |
d0c58841d1 | ||
![]() |
881f09b2e2 | ||
![]() |
8b30acd7eb | ||
![]() |
02efb310ca | ||
![]() |
e7e59c6f05 | ||
![]() |
3ae6aabe9f | ||
![]() |
dc627dcb5e | ||
![]() |
efeb9c0f02 | ||
![]() |
ba3e913c7a | ||
![]() |
7cca1727af | ||
![]() |
11371fe251 | ||
![]() |
41c603d48a | ||
![]() |
969337345f | ||
![]() |
9592766939 | ||
![]() |
58dca7d846 | ||
![]() |
0d302cd25b | ||
![]() |
da691257ec | ||
![]() |
1600092e92 | ||
![]() |
dba2bd1105 | ||
![]() |
28be4de7c2 | ||
![]() |
a6c3b38fba | ||
![]() |
fcb65a3897 | ||
![]() |
4e22a1dffe | ||
![]() |
291cf40aca | ||
![]() |
bd47e1f066 | ||
![]() |
e6b223df5f | ||
![]() |
e64349bbdd | ||
![]() |
cdb59faea6 | ||
![]() |
1d94ac3f90 | ||
![]() |
5f7d19d1f5 | ||
![]() |
2fdf9eb535 | ||
![]() |
860d3a50d7 | ||
![]() |
d1183821a7 | ||
![]() |
8081df79be | ||
![]() |
64bec4fad7 | ||
![]() |
b96e105244 | ||
![]() |
3b4d5484c7 | ||
![]() |
684e11c664 | ||
![]() |
b57a52813b | ||
![]() |
da8deb2b62 | ||
![]() |
98b6ce3460 | ||
![]() |
f9e00efe31 | ||
![]() |
0fd2a1f4b0 | ||
![]() |
df3233454d | ||
![]() |
82db84b899 | ||
![]() |
8ae751d3da | ||
![]() |
d40e76809f | ||
![]() |
bb1b76d9dc | ||
![]() |
9d26441224 | ||
![]() |
f12f24a77c | ||
![]() |
ae5b5cabfd | ||
![]() |
d0630ffe8c | ||
![]() |
99bb7d3a58 | ||
![]() |
63ae767232 | ||
![]() |
eaaea02010 | ||
![]() |
a098bc92e0 | ||
![]() |
1086dc4db0 | ||
![]() |
19fb69e2ed | ||
![]() |
9231617eb3 | ||
![]() |
32668a7317 | ||
![]() |
780c197f95 | ||
![]() |
eb8819e91e | ||
![]() |
30bbea2f08 | ||
![]() |
635ccd9e25 | ||
![]() |
8c9f0278b9 | ||
![]() |
58d0e199e1 | ||
![]() |
10b5835501 | ||
![]() |
6c8dd307eb | ||
![]() |
43ffdab172 | ||
![]() |
40b6d67333 | ||
![]() |
c52d1600f0 | ||
![]() |
aa1d6cadad | ||
![]() |
6e06e3a904 | ||
![]() |
8cfb9fc0b8 | ||
![]() |
7b456fd2c0 | ||
![]() |
e9e53856d2 | ||
![]() |
5029894662 | ||
![]() |
baf9fa5f42 | ||
![]() |
7f914365fd | ||
![]() |
ebd7135b50 | ||
![]() |
50eff6a10a | ||
![]() |
c34a5ae7f7 | ||
![]() |
e2aa6ec8ae | ||
![]() |
6768c6a54a | ||
![]() |
6307d166eb | ||
![]() |
1fba87b0df | ||
![]() |
df124e018a | ||
![]() |
2f83d6e4b7 | ||
![]() |
987785d8d7 | ||
![]() |
8c01a7893b | ||
![]() |
218047c75a | ||
![]() |
d0da74209b | ||
![]() |
5c1fa64fb0 | ||
![]() |
a3c287354f | ||
![]() |
03cf033f82 | ||
![]() |
bdb36c9a63 | ||
![]() |
20bb301195 | ||
![]() |
d6383a1c6a | ||
![]() |
b05bcfd27f | ||
![]() |
2615660e62 | ||
![]() |
5b0af4cdb1 | ||
![]() |
8c2e15e6c8 | ||
![]() |
56c8a33439 | ||
![]() |
4eef1e8a3e | ||
![]() |
95d11bda06 | ||
![]() |
af9079cc1f | ||
![]() |
2d6cd47713 | ||
![]() |
fe3167d7ea | ||
![]() |
31e134be35 | ||
![]() |
e84ba8056d | ||
![]() |
f20e97b092 | ||
![]() |
934683088e | ||
![]() |
de2b9e7d0a | ||
![]() |
dd7d8e5e29 | ||
![]() |
df964132fb | ||
![]() |
709ccc6800 | ||
![]() |
cf236fc390 | ||
![]() |
27d70c7d9d | ||
![]() |
0e585b4409 | ||
![]() |
0163a8e57a | ||
![]() |
578842954c | ||
![]() |
496315fe1d | ||
![]() |
0fe6895893 | ||
![]() |
0b7d71fd2f | ||
![]() |
83b11bc58d | ||
![]() |
375a8bbdcc | ||
![]() |
ea9090bbc4 | ||
![]() |
81def6ac76 | ||
![]() |
3de8ce3f3c | ||
![]() |
4d485fca24 | ||
![]() |
1865299a30 | ||
![]() |
3576b547c5 | ||
![]() |
079882495d | ||
![]() |
ab977109db | ||
![]() |
fd1c08137b | ||
![]() |
76b6cece46 | ||
![]() |
9f0df51f8d | ||
![]() |
e7a2a3dcd1 | ||
![]() |
a87ef5bfc1 | ||
![]() |
9f9cb7a2ef | ||
![]() |
7e26fd8032 | ||
![]() |
eab2685c67 | ||
![]() |
50dfb664db | ||
![]() |
0189ab6ab6 | ||
![]() |
9401507336 | ||
![]() |
eb8321d863 | ||
![]() |
79ef49b2c2 | ||
![]() |
e110ca11e2 | ||
![]() |
226748b3e7 | ||
![]() |
d568c7ee36 | ||
![]() |
e6fecbb3e1 | ||
![]() |
da83f899bb | ||
![]() |
7e5674d8be | ||
![]() |
0a558577bf | ||
![]() |
fb71a82ada | ||
![]() |
23406c9e9e | ||
![]() |
b3ec792380 | ||
![]() |
6a9b584f3d | ||
![]() |
81dd33af66 | ||
![]() |
8b76571896 | ||
![]() |
e78a6518fa | ||
![]() |
1873ffda01 | ||
![]() |
c417e42116 | ||
![]() |
358e1fd6ab | ||
![]() |
631dfbe673 | ||
![]() |
56a4eaed72 | ||
![]() |
bf925d9dc7 | ||
![]() |
1a7ed5dcb6 | ||
![]() |
5be5daa6ef | ||
![]() |
60cb11764e | ||
![]() |
cbd5445ea7 | ||
![]() |
2c7e9b5158 | ||
![]() |
2263e4b279 | ||
![]() |
863039da4c | ||
![]() |
7178ac0111 | ||
![]() |
e7f9710499 | ||
![]() |
ff4223904d | ||
![]() |
a9f80d60f6 | ||
![]() |
2e158cf6d0 | ||
![]() |
8bd6bfa4b5 | ||
![]() |
8b1906abd0 | ||
![]() |
06375e6605 | ||
![]() |
b21242faf1 | ||
![]() |
cc05a281c4 | ||
![]() |
fe96ceee66 | ||
![]() |
9814a2ae12 | ||
![]() |
6992498e7a | ||
![]() |
21623156a3 | ||
![]() |
79c859e2e0 | ||
![]() |
b00ac960b4 |
@@ -13,8 +13,62 @@ parameters:
|
||||
test_release:
|
||||
type: boolean
|
||||
default: false
|
||||
linux_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
parameters:
|
||||
upload-docs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install
|
||||
command: |
|
||||
brew install python@3.9
|
||||
brew install doxygen
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install -r docs/requirements.txt
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
|
||||
- when:
|
||||
condition:
|
||||
not: << parameters.upload-docs >>
|
||||
steps:
|
||||
- run:
|
||||
name: Build documentation
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd docs && doxygen && make html O=-W
|
||||
- when:
|
||||
condition: << parameters.upload-docs >>
|
||||
steps:
|
||||
- add_ssh_keys:
|
||||
fingerprints:
|
||||
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
|
||||
- run:
|
||||
name: Upload documentation
|
||||
command: |
|
||||
source env/bin/activate
|
||||
git config user.email "mlx@group.apple.com"
|
||||
git config user.name "CircleCI Docs"
|
||||
git checkout gh-pages
|
||||
git rebase main
|
||||
cd docs
|
||||
git rm -rf build/html
|
||||
doxygen && make html O=-W
|
||||
git add -f build/html
|
||||
git commit -m "rebase"
|
||||
git push -f origin gh-pages
|
||||
|
||||
linux_build_and_test:
|
||||
docker:
|
||||
- image: cimg/python:3.9
|
||||
@@ -31,33 +85,35 @@ jobs:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install nanobind==2.2.0
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py develop
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
# TODO: Reenable when extension api becomes stable
|
||||
# - run:
|
||||
# name: Build example extension
|
||||
# command: |
|
||||
# cd examples/extensions && python3 -m pip install .
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j `nproc`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: ./build/tests/tests
|
||||
@@ -69,18 +125,19 @@ jobs:
|
||||
default: "15.2.0"
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.large.gen1
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.8
|
||||
python3.8 -m venv env
|
||||
brew install python@3.9
|
||||
brew install openmpi
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install nanobind==2.2.0
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
@@ -89,11 +146,12 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
@@ -101,23 +159,47 @@ jobs:
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
# TODO: Reenable when extension api becomes stable
|
||||
# - run:
|
||||
# name: Build example extension
|
||||
# command: |
|
||||
# cd examples/extensions && python3.11 -m pip install .
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd examples/extensions
|
||||
pip install -r requirements.txt
|
||||
python setup.py build_ext -j8
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source env/bin/activate
|
||||
mkdir -p build && cd build && cmake .. && make -j
|
||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: |
|
||||
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
||||
DEVICE=cpu ./build/tests/tests
|
||||
- run:
|
||||
name: Build small binary
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd build/
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
@@ -132,18 +214,19 @@ jobs:
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.large.gen1
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@<< parameters.python_version >>
|
||||
brew install openmpi
|
||||
python<< parameters.python_version >> -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install nanobind==2.2.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install twine
|
||||
@@ -153,19 +236,20 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
<< parameters.build_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
python -m build -w
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
@@ -178,7 +262,7 @@ jobs:
|
||||
- store_artifacts:
|
||||
path: dist/
|
||||
|
||||
build_linux_test_release:
|
||||
build_linux_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
@@ -207,21 +291,28 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install nanobind==2.2.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
pip install . -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python -m build --wheel
|
||||
auditwheel show dist/*
|
||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
@@ -239,8 +330,9 @@ workflows:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
- linux_build_and_test
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
when:
|
||||
@@ -257,9 +349,17 @@ workflows:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
upload-docs: true
|
||||
|
||||
prb:
|
||||
when:
|
||||
matches:
|
||||
@@ -274,7 +374,7 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
@@ -286,7 +386,7 @@ workflows:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
weekly_build:
|
||||
when:
|
||||
@@ -297,17 +397,17 @@ workflows:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.test_release >>
|
||||
- << pipeline.parameters.linux_release >>
|
||||
jobs:
|
||||
- build_linux_test_release:
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
|
2
.github/workflows/pull_request.yml
vendored
2
.github/workflows/pull_request.yml
vendored
@@ -17,4 +17,4 @@ jobs:
|
||||
pip install pre-commit black isort clang-format
|
||||
- name: Run lint
|
||||
run: |
|
||||
pre-commit run --all-files
|
||||
pre-commit run --all-files
|
||||
|
@@ -1,11 +1,11 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.4
|
||||
rev: v18.1.8
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.4.2
|
||||
rev: 24.8.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
@@ -14,3 +14,7 @@ repos:
|
||||
- id: isort
|
||||
args:
|
||||
- --profile=black
|
||||
- repo: https://github.com/cheshirekow/cmake-format-precommit
|
||||
rev: v0.6.13
|
||||
hooks:
|
||||
- id: cmake-format
|
||||
|
@@ -7,15 +7,18 @@ with a short description of your contribution(s) below. For example:
|
||||
|
||||
MLX was developed with contributions from the following individuals:
|
||||
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
|
||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
24
CITATION.cff
Normal file
24
CITATION.cff
Normal file
@@ -0,0 +1,24 @@
|
||||
cff-version: 1.2.0
|
||||
title: mlx
|
||||
message: >-
|
||||
If you use this software, please cite it using the
|
||||
metadata from this file.
|
||||
type: software
|
||||
authors:
|
||||
- given-names: Awni
|
||||
family-names: Hannun
|
||||
affiliation: Apple
|
||||
- given-names: Jagrit
|
||||
family-names: Digani
|
||||
affiliation: Apple
|
||||
- given-names: Angelos
|
||||
family-names: Katharopoulos
|
||||
affiliation: Apple
|
||||
- given-names: Ronan
|
||||
family-names: Collobert
|
||||
affiliation: Apple
|
||||
repository-code: 'https://github.com/ml-explore'
|
||||
abstract: >-
|
||||
MLX: efficient and flexible machine learning on Apple
|
||||
silicon
|
||||
license: MIT
|
289
CMakeLists.txt
289
CMakeLists.txt
@@ -15,40 +15,52 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.12.2)
|
||||
set(MLX_VERSION 0.19.0)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
|
||||
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
||||
message(
|
||||
STATUS
|
||||
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
||||
)
|
||||
|
||||
set(MLX_BUILD_ARM OFF)
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
if(NOT MLX_ENABLE_X64_MAC)
|
||||
message(FATAL_ERROR
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, check the build"
|
||||
" documentation for possible fixes: "
|
||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, check the build"
|
||||
" documentation for possible fixes: "
|
||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
||||
)
|
||||
else()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||
endif()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
set(MLX_BUILD_ARM ON)
|
||||
endif()
|
||||
|
||||
else()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||
endif()
|
||||
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
set(MLX_BUILD_ARM ON)
|
||||
endif()
|
||||
|
||||
# ----------------------------- Lib -----------------------------
|
||||
|
||||
include(FetchContent)
|
||||
@@ -57,170 +69,192 @@ cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
add_library(mlx)
|
||||
|
||||
if (MLX_BUILD_METAL)
|
||||
find_library(METAL_LIB Metal)
|
||||
find_library(FOUNDATION_LIB Foundation)
|
||||
find_library(QUARTZ_LIB QuartzCore)
|
||||
if(MLX_BUILD_METAL)
|
||||
set(METAL_LIB "-framework Metal")
|
||||
set(FOUNDATION_LIB "-framework Foundation")
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
set(MLX_METAL_DEBUG OFF)
|
||||
elseif (MLX_BUILD_METAL)
|
||||
elseif(MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
|
||||
if (MLX_METAL_DEBUG)
|
||||
if(MLX_METAL_DEBUG)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_VERSION
|
||||
COMMAND_ERROR_IS_FATAL ANY)
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(${MACOS_VERSION} LESS 14.0)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
|
||||
endif()
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
||||
else()
|
||||
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
metal_cpp
|
||||
URL ${METAL_CPP_URL}
|
||||
PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
|
||||
set(METAL_CPP_URL
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
|
||||
)
|
||||
# Get the metal version
|
||||
execute_process(
|
||||
COMMAND
|
||||
zsh "-c"
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
mlx PUBLIC
|
||||
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||
$<INSTALL_INTERFACE:include/metal_cpp>
|
||||
)
|
||||
target_link_libraries(
|
||||
mlx
|
||||
${METAL_LIB}
|
||||
${FOUNDATION_LIB}
|
||||
${QUARTZ_LIB})
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||
$<INSTALL_INTERFACE:include/metal_cpp>)
|
||||
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||
|
||||
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
|
||||
endif()
|
||||
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
if(MLX_BUILD_CPU)
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
set(BLA_VENDOR OpenBLAS)
|
||||
set(LAPACK_ROOT
|
||||
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||
endif()
|
||||
# Search and link with lapack.
|
||||
find_package(LAPACK REQUIRED)
|
||||
if(NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
|
||||
/usr/local/opt/openblas/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally incldue an old
|
||||
# version of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
if(NOT BLAS_FOUND)
|
||||
message(FATAL_ERROR "Must have BLAS installed")
|
||||
endif()
|
||||
# TODO find a cleaner way to do this
|
||||
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
|
||||
$ENV{BLAS_HOME}/include)
|
||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
set(BLA_VENDOR OpenBLAS)
|
||||
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||
endif()
|
||||
|
||||
find_package(MPI)
|
||||
if(MPI_FOUND)
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "mpirun --version"
|
||||
OUTPUT_VARIABLE MPI_VERSION
|
||||
ERROR_QUIET)
|
||||
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
|
||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
||||
elseif(MPI_VERSION STREQUAL "")
|
||||
set(MPI_FOUND FALSE)
|
||||
message(
|
||||
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
||||
else()
|
||||
set(MPI_FOUND FALSE)
|
||||
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
||||
endif()
|
||||
# Search and link with lapack.
|
||||
find_package(LAPACK REQUIRED)
|
||||
if (NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
/usr/local/opt/openblas/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally incldue an old version
|
||||
# of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
if (NOT BLAS_FOUND)
|
||||
message(FATAL_ERROR "Must have BLAS installed")
|
||||
endif()
|
||||
# TODO find a cleaner way to do this
|
||||
find_path(BLAS_INCLUDE_DIRS cblas.h
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
$ENV{BLAS_HOME}/include)
|
||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||
|
||||
target_include_directories(
|
||||
mlx
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
if (MLX_BUILD_PYTHON_BINDINGS)
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||
|
||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
message(STATUS "Building Python bindings.")
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
find_package(
|
||||
Python 3.8
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_TESTS)
|
||||
if(MLX_BUILD_TESTS)
|
||||
include(CTest)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_EXAMPLES)
|
||||
if(MLX_BUILD_EXAMPLES)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_BENCHMARKS)
|
||||
if(MLX_BUILD_BENCHMARKS)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
# ----------------------------- Installation -----------------------------
|
||||
include(GNUInstallDirs)
|
||||
|
||||
# Install library
|
||||
install(
|
||||
TARGETS mlx
|
||||
EXPORT MLXTargets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||
)
|
||||
|
||||
TARGETS mlx
|
||||
EXPORT MLXTargets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
INCLUDES
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
# Install headers
|
||||
install(
|
||||
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||
COMPONENT headers
|
||||
FILES_MATCHING PATTERN "*.h"
|
||||
)
|
||||
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||
COMPONENT headers
|
||||
FILES_MATCHING
|
||||
PATTERN "*.h"
|
||||
PATTERN "backend/metal/kernels.h" EXCLUDE)
|
||||
|
||||
# Install metal dependencies
|
||||
if (MLX_BUILD_METAL)
|
||||
if(MLX_BUILD_METAL)
|
||||
|
||||
# Install metal cpp
|
||||
install(
|
||||
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
||||
COMPONENT metal_cpp_source
|
||||
)
|
||||
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
||||
COMPONENT metal_cpp_source)
|
||||
|
||||
endif()
|
||||
|
||||
@@ -232,31 +266,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
|
||||
install(
|
||||
EXPORT MLXTargets
|
||||
FILE MLXTargets.cmake
|
||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||
)
|
||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
|
||||
write_basic_package_version_file(
|
||||
${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||
COMPATIBILITY SameMajorVersion
|
||||
VERSION ${MLX_VERSION}
|
||||
)
|
||||
VERSION ${MLX_VERSION})
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
|
||||
${MLX_CMAKE_BUILD_CONFIG}
|
||||
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
|
||||
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
||||
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
|
||||
)
|
||||
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
|
||||
MLX_CMAKE_INSTALL_MODULE_DIR)
|
||||
|
||||
install(
|
||||
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||
)
|
||||
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||
|
||||
install(
|
||||
DIRECTORY ${CMAKE_MODULE_PATH}/
|
||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||
)
|
||||
install(DIRECTORY ${CMAKE_MODULE_PATH}/
|
||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||
|
@@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.
|
||||
|
||||
## Contributing
|
||||
|
||||
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
|
||||
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
|
||||
on contributing to MLX. See the
|
||||
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
|
||||
information on building from source, and running tests.
|
||||
|
||||
We are grateful for all of [our
|
||||
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
||||
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
||||
to MLX and wish to be acknowledged, please add your name to the list in your
|
||||
pull request.
|
||||
|
||||
|
@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
|
||||
def mish(x: torch.Tensor) -> torch.Tensor:
|
||||
y = x
|
||||
for _ in range(100):
|
||||
return torch.nn.functional.mish(y)
|
||||
y = torch.nn.functional.mish(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@@ -283,6 +283,14 @@ def topk(axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step_function(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.where(y < 0, 0, 1)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def selu(x):
|
||||
y = x
|
||||
@@ -446,5 +454,11 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "topk":
|
||||
print(bench(topk, axis, x))
|
||||
|
||||
elif args.benchmark == "step":
|
||||
print(bench(step_function, x))
|
||||
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown benchmark")
|
||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||
|
@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
|
||||
result = run(*args, capture_output=True, **kwargs)
|
||||
return float(result.stdout)
|
||||
except ValueError:
|
||||
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
|
||||
raise ValueError(
|
||||
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
|
||||
)
|
||||
|
||||
|
||||
def compare(args):
|
||||
|
@@ -9,7 +9,6 @@ from time_utils import time_fn
|
||||
|
||||
|
||||
def bench_gelu():
|
||||
|
||||
def gelu(x):
|
||||
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||
|
||||
@@ -51,7 +50,6 @@ def bench_gelu():
|
||||
|
||||
|
||||
def bench_layernorm():
|
||||
|
||||
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
mx.eval(weight, bias)
|
||||
|
127
benchmarks/python/conv2d_bench_cpu.py
Normal file
127
benchmarks/python/conv2d_bench_cpu.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 1
|
||||
N_iter_bench = 10
|
||||
N_iter_func = 5
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
return ys
|
||||
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
||||
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
||||
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
143
benchmarks/python/conv2d_train_bench_cpu.py
Normal file
143
benchmarks/python/conv2d_train_bench_cpu.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn
|
||||
import mlx.optimizers as opt
|
||||
import torch
|
||||
|
||||
|
||||
def bench_mlx(steps: int = 20) -> float:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
class BenchNetMLX(mlx.nn.Module):
|
||||
# simple encoder-decoder net
|
||||
|
||||
def __init__(self, in_channels, hidden_channels=32):
|
||||
super().__init__()
|
||||
|
||||
self.net = mlx.nn.Sequential(
|
||||
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
||||
mlx.nn.ReLU(),
|
||||
mlx.nn.Conv2d(
|
||||
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
||||
),
|
||||
mlx.nn.ReLU(),
|
||||
mlx.nn.ConvTranspose2d(
|
||||
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
||||
),
|
||||
mlx.nn.ReLU(),
|
||||
mlx.nn.ConvTranspose2d(
|
||||
hidden_channels, in_channels, kernel_size=3, padding=1
|
||||
),
|
||||
)
|
||||
|
||||
def __call__(self, input):
|
||||
return self.net(input)
|
||||
|
||||
benchNet = BenchNetMLX(3)
|
||||
mx.eval(benchNet.parameters())
|
||||
optim = opt.Adam(learning_rate=1e-3)
|
||||
|
||||
inputs = mx.random.normal([10, 256, 256, 3])
|
||||
|
||||
params = benchNet.parameters()
|
||||
optim.init(params)
|
||||
|
||||
state = [benchNet.state, optim.state]
|
||||
|
||||
def loss_fn(params, image):
|
||||
benchNet.update(params)
|
||||
pred_image = benchNet(image)
|
||||
return (pred_image - image).abs().mean()
|
||||
|
||||
def step(params, image):
|
||||
loss, grads = mx.value_and_grad(loss_fn)(params, image)
|
||||
optim.update(benchNet, grads)
|
||||
return loss
|
||||
|
||||
total_time = 0.0
|
||||
print("MLX:")
|
||||
for i in range(steps):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
step(benchNet.parameters(), inputs)
|
||||
mx.eval(state)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
||||
total_time += (end_time - start_time) * 1000
|
||||
|
||||
return total_time
|
||||
|
||||
|
||||
def bench_torch(steps: int = 20) -> float:
|
||||
device = torch.device("cpu")
|
||||
|
||||
class BenchNetTorch(torch.nn.Module):
|
||||
# simple encoder-decoder net
|
||||
|
||||
def __init__(self, in_channels, hidden_channels=32):
|
||||
super().__init__()
|
||||
|
||||
self.net = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(
|
||||
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.ConvTranspose2d(
|
||||
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.ConvTranspose2d(
|
||||
hidden_channels, in_channels, kernel_size=3, padding=1
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.net(input)
|
||||
|
||||
benchNet = BenchNetTorch(3).to(device)
|
||||
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
|
||||
|
||||
inputs = torch.randn(10, 3, 256, 256, device=device)
|
||||
|
||||
def loss_fn(pred_image, image):
|
||||
return (pred_image - image).abs().mean()
|
||||
|
||||
total_time = 0.0
|
||||
print("PyTorch:")
|
||||
for i in range(steps):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
optim.zero_grad()
|
||||
pred_image = benchNet(inputs)
|
||||
loss = loss_fn(pred_image, inputs)
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
|
||||
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
||||
total_time += (end_time - start_time) * 1000
|
||||
|
||||
return total_time
|
||||
|
||||
|
||||
def main():
|
||||
steps = 20
|
||||
time_mlx = bench_mlx(steps)
|
||||
time_torch = bench_torch(steps)
|
||||
|
||||
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
|
||||
print(f"total time of MLX: {time_mlx:9.2f} ms")
|
||||
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
|
||||
print(f"total time of PyTorch: {time_torch:9.2f} ms")
|
||||
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
129
benchmarks/python/conv2d_transpose_bench_cpu.py
Normal file
129
benchmarks/python/conv2d_transpose_bench_cpu.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 1
|
||||
N_iter_bench = 10
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_transpose_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv_transpose2d(
|
||||
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
|
||||
)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_transpose_2D
|
||||
|
||||
|
||||
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_transpose_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv_transpose2d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
return ys
|
||||
|
||||
return pt_conv_transpose_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
|
||||
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
|
||||
|
||||
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv_transpose2d(
|
||||
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
|
||||
)
|
||||
out_pt = torch.conv_transpose2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
110
benchmarks/python/conv3d_bench_cpu.py
Normal file
110
benchmarks/python/conv3d_bench_cpu.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 1
|
||||
N_iter_bench = 10
|
||||
N_iter_func = 5
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_3D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_3D
|
||||
|
||||
|
||||
def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_3D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
return ys
|
||||
|
||||
return pt_conv_3D
|
||||
|
||||
|
||||
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kD * kH * kW * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")
|
||||
|
||||
f_mx = make_mx_conv_3D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_3D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv3d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
|
||||
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print(
|
||||
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
143
benchmarks/python/conv3d_train_bench_cpu.py
Normal file
143
benchmarks/python/conv3d_train_bench_cpu.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn
|
||||
import mlx.optimizers as opt
|
||||
import torch
|
||||
|
||||
|
||||
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
class BenchNetMLX(mlx.nn.Module):
|
||||
# simple encoder-decoder net
|
||||
|
||||
def __init__(self, in_channels, hidden_channels=16):
|
||||
super().__init__()
|
||||
|
||||
self.net = mlx.nn.Sequential(
|
||||
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
||||
mlx.nn.ReLU(),
|
||||
mlx.nn.Conv3d(
|
||||
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
||||
),
|
||||
mlx.nn.ReLU(),
|
||||
mlx.nn.ConvTranspose3d(
|
||||
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
||||
),
|
||||
mlx.nn.ReLU(),
|
||||
mlx.nn.ConvTranspose3d(
|
||||
hidden_channels, in_channels, kernel_size=3, padding=1
|
||||
),
|
||||
)
|
||||
|
||||
def __call__(self, input):
|
||||
return self.net(input)
|
||||
|
||||
benchNet = BenchNetMLX(3)
|
||||
mx.eval(benchNet.parameters())
|
||||
optim = opt.Adam(learning_rate=1e-3)
|
||||
|
||||
inputs = mx.random.normal(shape)
|
||||
|
||||
params = benchNet.parameters()
|
||||
optim.init(params)
|
||||
|
||||
state = [benchNet.state, optim.state]
|
||||
|
||||
def loss_fn(params, image):
|
||||
benchNet.update(params)
|
||||
pred_image = benchNet(image)
|
||||
return (pred_image - image).abs().mean()
|
||||
|
||||
def step(params, image):
|
||||
loss, grads = mx.value_and_grad(loss_fn)(params, image)
|
||||
optim.update(benchNet, grads)
|
||||
return loss
|
||||
|
||||
total_time = 0.0
|
||||
print("MLX:")
|
||||
for i in range(steps):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
step(benchNet.parameters(), inputs)
|
||||
mx.eval(state)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
||||
total_time += (end_time - start_time) * 1000
|
||||
|
||||
return total_time
|
||||
|
||||
|
||||
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
|
||||
device = torch.device("cpu")
|
||||
|
||||
class BenchNetTorch(torch.nn.Module):
|
||||
# simple encoder-decoder net
|
||||
|
||||
def __init__(self, in_channels, hidden_channels=16):
|
||||
super().__init__()
|
||||
|
||||
self.net = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv3d(
|
||||
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.ConvTranspose3d(
|
||||
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.ConvTranspose3d(
|
||||
hidden_channels, in_channels, kernel_size=3, padding=1
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.net(input)
|
||||
|
||||
benchNet = BenchNetTorch(3).to(device)
|
||||
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
|
||||
|
||||
inputs = torch.randn(*shape, device=device)
|
||||
|
||||
def loss_fn(pred_image, image):
|
||||
return (pred_image - image).abs().mean()
|
||||
|
||||
total_time = 0.0
|
||||
print("PyTorch:")
|
||||
for i in range(steps):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
optim.zero_grad()
|
||||
pred_image = benchNet(inputs)
|
||||
loss = loss_fn(pred_image, inputs)
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
|
||||
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
||||
total_time += (end_time - start_time) * 1000
|
||||
|
||||
return total_time
|
||||
|
||||
|
||||
def main():
|
||||
steps = 10
|
||||
time_mlx = bench_mlx(steps)
|
||||
time_torch = bench_torch(steps)
|
||||
|
||||
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
|
||||
print(f"total time of MLX: {time_mlx:9.2f} ms")
|
||||
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
|
||||
print(f"total time of PyTorch: {time_torch:9.2f} ms")
|
||||
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
116
benchmarks/python/conv3d_transpose_bench_cpu.py
Normal file
116
benchmarks/python/conv3d_transpose_bench_cpu.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 1
|
||||
N_iter_bench = 10
|
||||
N_iter_func = 5
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
|
||||
def mx_conv_3D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv_transpose3d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_3D
|
||||
|
||||
|
||||
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_3D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv_transpose3d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
return ys
|
||||
|
||||
return pt_conv_3D
|
||||
|
||||
|
||||
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kD * kH * kW * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
|
||||
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
|
||||
|
||||
f_mx = make_mx_conv_3D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_3D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv_transpose3d(
|
||||
a_mx, b_mx, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.conv_transpose3d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
|
||||
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print(
|
||||
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@@ -28,11 +28,11 @@ def bench(f, a, b):
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding)
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding)
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
@@ -53,11 +53,12 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
@@ -67,15 +68,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding)
|
||||
f_pt = make_pt_conv_2D(strides, padding)
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
@@ -84,7 +85,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
@@ -95,35 +96,40 @@ if __name__ == "__main__":
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
|
||||
for N, H, W, C, kH, kW, O, strides, padding in shapes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, np_dtype
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
|
135
benchmarks/python/conv_transpose_bench.py
Normal file
135
benchmarks/python/conv_transpose_bench.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_transpose_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv_transpose2d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_transpose_2D
|
||||
|
||||
|
||||
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_transpose_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv_transpose2d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_transpose_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv_transpose2d(
|
||||
a_mx, b_mx, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.conv_transpose2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
66
benchmarks/python/distributed_bench.py
Normal file
66
benchmarks/python/distributed_bench.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
Run with:
|
||||
mpirun -n 2 python /path/to/distributed_bench.py
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def time_fn(fn, *args, **kwargs):
|
||||
msg = kwargs.pop("msg", None)
|
||||
world = mx.distributed.init()
|
||||
if world.rank() == 0:
|
||||
if msg:
|
||||
print(f"Timing {msg} ...", end=" ")
|
||||
else:
|
||||
print(f"Timing {fn.__name__} ...", end=" ")
|
||||
|
||||
# warmup
|
||||
for _ in range(5):
|
||||
mx.eval(fn(*args, **kwargs))
|
||||
|
||||
num_iters = 100
|
||||
tic = time.perf_counter()
|
||||
for _ in range(num_iters):
|
||||
x = mx.eval(fn(*args, **kwargs))
|
||||
toc = time.perf_counter()
|
||||
|
||||
msec = 1e3 * (toc - tic) / num_iters
|
||||
if world.rank() == 0:
|
||||
print(f"{msec:.5f} msec")
|
||||
|
||||
|
||||
def time_all_sum():
|
||||
shape = (4096,)
|
||||
x = mx.random.uniform(shape=shape)
|
||||
mx.eval(x)
|
||||
|
||||
def sine(x):
|
||||
for _ in range(20):
|
||||
x = mx.sin(x)
|
||||
return x
|
||||
|
||||
time_fn(sine, x)
|
||||
|
||||
def all_sum_plain(x):
|
||||
for _ in range(20):
|
||||
x = mx.distributed.all_sum(x)
|
||||
return x
|
||||
|
||||
time_fn(all_sum_plain, x)
|
||||
|
||||
def all_sum_with_sine(x):
|
||||
for _ in range(20):
|
||||
x = mx.sin(x)
|
||||
x = mx.distributed.all_sum(x)
|
||||
return x
|
||||
|
||||
time_fn(all_sum_with_sine, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_all_sum()
|
84
benchmarks/python/einsum_bench.py
Normal file
84
benchmarks/python/einsum_bench.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def timeit(fn, its=100, args=[]):
|
||||
for _ in range(5):
|
||||
fn(*args)
|
||||
tic = time.perf_counter()
|
||||
for _ in range(its):
|
||||
fn(*args)
|
||||
toc = time.perf_counter()
|
||||
return 1e3 * (toc - tic) / its
|
||||
|
||||
|
||||
def time_little_einsum_path():
|
||||
subscripts = "ik,kj->ij"
|
||||
x = mx.ones((32, 32))
|
||||
y = mx.ones((32, 32))
|
||||
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
|
||||
|
||||
x = np.array(x)
|
||||
y = np.array(y)
|
||||
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
|
||||
print("Timing little einsum path...")
|
||||
print(f"MLX ... {mx_time:.3f} ms")
|
||||
print(f"NumPy... {np_time:.3f} ms")
|
||||
|
||||
|
||||
def time_big_einsum_path():
|
||||
chars = list("abcdefgh")
|
||||
char_to_dim = {c: v for v, c in enumerate(chars)}
|
||||
|
||||
num_inputs = 10
|
||||
inputs = []
|
||||
subscripts = []
|
||||
for _ in range(num_inputs):
|
||||
subscript = np.random.choice(chars, size=5, replace=False).tolist()
|
||||
subscripts.append("".join(subscript))
|
||||
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
|
||||
subscripts = ",".join(subscripts)
|
||||
|
||||
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
|
||||
|
||||
inputs = [mx.array(x) for x in inputs]
|
||||
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
|
||||
print("Timing big einsum path...")
|
||||
print(f"MLX ... {mx_time:.3f} ms")
|
||||
print(f"NumPy... {np_time:.3f} ms")
|
||||
|
||||
|
||||
def time_attention():
|
||||
def regular_attention(x):
|
||||
# shape [batch, sequence, num_heads, head_dim]
|
||||
queries, keys, values = x, x, x
|
||||
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
|
||||
mx.eval(output)
|
||||
|
||||
def einsum_attention(x):
|
||||
# shape [batch, sequence, num_heads, head_dim]
|
||||
queries, keys, values = x, x, x
|
||||
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
output = mx.einsum("ijtu,iujk->itjk", scores, values)
|
||||
mx.eval(output)
|
||||
|
||||
x = mx.random.uniform(shape=(8, 512, 32, 128))
|
||||
|
||||
regular_time = timeit(regular_attention, args=(x,))
|
||||
ein_time = timeit(einsum_attention, args=(x,))
|
||||
print("Timing einsum attention...")
|
||||
print(f"Regular ... {regular_time:.3f} ms")
|
||||
print(f"Einsum ... {ein_time:.3f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_little_einsum_path()
|
||||
time_big_einsum_path()
|
||||
time_attention()
|
@@ -3,6 +3,8 @@
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import sympy
|
||||
import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@@ -16,41 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
|
||||
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
|
||||
|
||||
|
||||
def run_bench(system_size):
|
||||
def fft(x):
|
||||
out = mx.fft.fft(x)
|
||||
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
|
||||
def fft_mlx(x):
|
||||
if dim == 1:
|
||||
out = mx.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = mx.fft.fft2(x)
|
||||
mx.eval(out)
|
||||
return out
|
||||
|
||||
bandwidths = []
|
||||
for k in range(4, 12):
|
||||
n = 2**k
|
||||
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
|
||||
x = x.astype(mx.complex64)
|
||||
mx.eval(x)
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
|
||||
def fft_mps(x):
|
||||
if dim == 1:
|
||||
out = torch.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = torch.fft.fft2(x)
|
||||
torch.mps.synchronize()
|
||||
return out
|
||||
|
||||
return bandwidths
|
||||
bandwidths = []
|
||||
for n in fft_sizes:
|
||||
batch_size = system_size // n**dim
|
||||
shape = [batch_size] + [n for _ in range(dim)]
|
||||
if backend == "mlx":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = mx.array(x_np)
|
||||
mx.eval(x)
|
||||
fft = fft_mlx
|
||||
elif backend == "mps":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = torch.tensor(x_np, device="mps")
|
||||
torch.mps.synchronize()
|
||||
fft = fft_mps
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
|
||||
print(n, bandwidth)
|
||||
bandwidths.append(bandwidth)
|
||||
|
||||
return np.array(bandwidths)
|
||||
|
||||
|
||||
def time_fft():
|
||||
x = np.array(range(2, 512))
|
||||
system_size = int(2**26)
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=int(2**22))
|
||||
|
||||
print("MLX GPU")
|
||||
with mx.stream(mx.gpu):
|
||||
gpu_bandwidths = run_bench(system_size=int(2**29))
|
||||
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
# plot bandwidths
|
||||
x = [2**k for k in range(4, 12)]
|
||||
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
|
||||
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
|
||||
plt.title("MLX FFT Benchmark")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig("fft_plot.png")
|
||||
print("MPS GPU")
|
||||
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
|
||||
|
||||
print("CPU")
|
||||
system_size = int(2**20)
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
x = np.array(x)
|
||||
|
||||
all_indices = x - x[0]
|
||||
radix_2to13 = (
|
||||
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
bluesteins = (
|
||||
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
|
||||
for indices, name in [
|
||||
(all_indices, "All"),
|
||||
(radix_2to13, "Radix 2-13"),
|
||||
(bluesteins, "Bluestein's"),
|
||||
]:
|
||||
# plot bandwidths
|
||||
print(name)
|
||||
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
|
||||
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
|
||||
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
|
||||
plt.title(f"MLX FFT Benchmark -- {name}")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig(f"{name}.png")
|
||||
plt.clf()
|
||||
|
||||
av_gpu_bandwidth = np.mean(gpu_bandwidths)
|
||||
av_mps_bandwidth = np.mean(mps_bandwidths)
|
||||
av_cpu_bandwidth = np.mean(cpu_bandwidths)
|
||||
print("Average bandwidths:")
|
||||
print("GPU:", av_gpu_bandwidth)
|
||||
print("MPS:", av_mps_bandwidth)
|
||||
print("CPU:", av_cpu_bandwidth)
|
||||
|
||||
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
|
||||
print("Percent MLX faster than MPS: ", portion_faster * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
70
benchmarks/python/hadamard_bench.py
Normal file
70
benchmarks/python/hadamard_bench.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import argparse
|
||||
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def had(x):
|
||||
y = mx.hadamard_transform(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def copy(x):
|
||||
y = x + 1.0
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def run(dtype):
|
||||
system_size = 2**26
|
||||
outputs = {}
|
||||
for test_fn in (had, copy):
|
||||
for m in [1, 12, 20, 28]:
|
||||
if test_fn == copy:
|
||||
key = "copy"
|
||||
elif m == 1:
|
||||
key = "had_2^k"
|
||||
else:
|
||||
key = "had_m*2^k"
|
||||
outputs.setdefault(key, {})
|
||||
for k in range(7, 14):
|
||||
n = m * 2**k
|
||||
if n > 2**15:
|
||||
continue
|
||||
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
|
||||
x = mx.array(x_np)
|
||||
runtime_ms = measure_runtime(test_fn, x=x)
|
||||
bytes_per_gb = 1e9
|
||||
ms_per_s = 1e3
|
||||
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
|
||||
bandwidth_gb = (
|
||||
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
|
||||
)
|
||||
print(n, bandwidth_gb)
|
||||
outputs[key][n] = bandwidth_gb
|
||||
|
||||
colors = {
|
||||
"copy": "black",
|
||||
"had_2^k": "steelblue",
|
||||
"had_m*2^k": "skyblue",
|
||||
}
|
||||
for key, output in outputs.items():
|
||||
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
|
||||
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig(f"bench_{dtype.__name__}.png")
|
||||
plt.clf()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
args = parser.parse_args()
|
||||
dtype = np.float16 if args.fp16 else np.float32
|
||||
run(dtype)
|
62
benchmarks/python/sdpa_bench.py
Normal file
62
benchmarks/python/sdpa_bench.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
MAX_SEQ = 300
|
||||
START_SEQ = 100
|
||||
SEQ_INCREMENT = 50
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
def sdpa_primitives(qs, ks, vs, alpha):
|
||||
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ vs
|
||||
return o
|
||||
|
||||
time_fn(sdpa_primitives, q, k, v, scale)
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
def sdpa_fused(qs, ks, vs, alpha):
|
||||
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
|
||||
return o
|
||||
|
||||
time_fn(sdpa_fused, q, k, v, scale)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
81
benchmarks/python/sdpa_vector_bench.py
Normal file
81
benchmarks/python/sdpa_vector_bench.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mlx.utils import tree_map
|
||||
from time_utils import time_fn
|
||||
|
||||
L = 65536
|
||||
H = 32
|
||||
H_k = 32 // 4
|
||||
D = 128
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
|
||||
|
||||
|
||||
def quant_sdpa(q, k, v, bits=4):
|
||||
return mx.fast.quantized_scaled_dot_product_attention(
|
||||
q, *k, *v, scale=1.0, mask=None, bits=bits
|
||||
)
|
||||
|
||||
|
||||
def quant_attention(q, k, v, bits=4):
|
||||
B, Hq, L, D = q.shape
|
||||
Hk = k[0].shape[1]
|
||||
|
||||
q = q.reshape((B, Hk, Hq // Hk, L, D))
|
||||
k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
|
||||
v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
|
||||
|
||||
scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
|
||||
out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
|
||||
out = out.reshape((B, Hq, L, D))
|
||||
return out
|
||||
|
||||
|
||||
def time_self_attention_primitives(q, k, v):
|
||||
time_fn(attention, q, k, v)
|
||||
|
||||
|
||||
def time_self_attention_sdpa(q, k, v):
|
||||
time_fn(sdpa, q, k, v)
|
||||
|
||||
|
||||
def time_self_attention_quant_sdpa(q, k, v, bits=4):
|
||||
time_fn(quant_sdpa, q, k, v, bits)
|
||||
|
||||
|
||||
def time_self_attention_quant_primitives(q, k, v, bits=4):
|
||||
time_fn(quant_attention, q, k, v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
bits = 4
|
||||
k_quant = mx.quantize(k, bits=bits)
|
||||
v_quant = mx.quantize(v, bits=bits)
|
||||
mx.eval(k_quant, v_quant)
|
||||
|
||||
time_self_attention_sdpa(q, k, v)
|
||||
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
|
||||
time_self_attention_primitives(q, k, v)
|
||||
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
|
@@ -1,56 +1,41 @@
|
||||
include(CMakeParseArguments)
|
||||
|
||||
###############################################################################
|
||||
# ##############################################################################
|
||||
# Build metal library
|
||||
#
|
||||
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
|
||||
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
|
||||
#
|
||||
# Args:
|
||||
# TARGET: Custom target to be added for the metal library
|
||||
# TITLE: Name of the .metallib
|
||||
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
|
||||
# SOURCES: List of source files
|
||||
# INCLUDE_DIRS: List of include dirs
|
||||
# DEPS: List of dependency files (like headers)
|
||||
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||
# files (like headers)
|
||||
#
|
||||
macro(mlx_build_metallib)
|
||||
# Parse args
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||
cmake_parse_arguments(
|
||||
MTLLIB
|
||||
""
|
||||
"${oneValueArgs}"
|
||||
"${multiValueArgs}"
|
||||
${ARGN}
|
||||
)
|
||||
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
# Set output
|
||||
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
||||
|
||||
# Collect compile options
|
||||
# Collect compile options
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
||||
|
||||
# Prepare metallib build command
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND xcrun -sdk macosx metal
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS}
|
||||
${MTLLIB_SOURCES}
|
||||
-o ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND
|
||||
xcrun -sdk macosx metal
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||
COMMAND_EXPAND_LISTS
|
||||
COMMENT "Building ${MTLLIB_TITLE}.metallib"
|
||||
VERBATIM
|
||||
)
|
||||
VERBATIM)
|
||||
|
||||
# Add metallib custom target
|
||||
add_custom_target(
|
||||
${MTLLIB_TARGET}
|
||||
DEPENDS
|
||||
${MTLLIB_BUILD_TARGET}
|
||||
)
|
||||
add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
|
||||
|
||||
endmacro(mlx_build_metallib)
|
||||
endmacro(mlx_build_metallib)
|
||||
|
@@ -1,36 +0,0 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
|
||||
@@ -1906,6 +1906,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
@@ -1,36 +0,0 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
|
||||
@@ -1918,6 +1918,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
@@ -1,3 +1,4 @@
|
||||
sphinx
|
||||
breathe
|
||||
sphinx-book-theme
|
||||
mlx
|
||||
|
@@ -83,3 +83,15 @@ def setup(app):
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
|
||||
latex_elements = {
|
||||
"preamble": r"""
|
||||
\usepackage{enumitem}
|
||||
\setlistdepth{5}
|
||||
\setlist[itemize,1]{label=$\bullet$}
|
||||
\setlist[itemize,2]{label=$\bullet$}
|
||||
\setlist[itemize,3]{label=$\bullet$}
|
||||
\setlist[itemize,4]{label=$\bullet$}
|
||||
\setlist[itemize,5]{label=$\bullet$}
|
||||
\renewlist{itemize}{itemize}{5}
|
||||
""",
|
||||
}
|
||||
|
421
docs/src/dev/custom_metal_kernels.rst
Normal file
421
docs/src/dev/custom_metal_kernels.rst
Normal file
@@ -0,0 +1,421 @@
|
||||
Custom Metal Kernels
|
||||
====================
|
||||
|
||||
MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||
|
||||
Simple Example
|
||||
--------------
|
||||
|
||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
.. note::
|
||||
We are only required to pass the body of the Metal kernel in ``source``.
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
* The shapes/dtypes of ``inputs``
|
||||
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
|
||||
so we will add ``const device float16_t* inp`` to the signature.
|
||||
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
|
||||
in ``source``.
|
||||
* The list of ``output_dtypes``
|
||||
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
|
||||
so we add ``device float16_t* out``.
|
||||
* Template parameters passed using ``template``
|
||||
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
|
||||
and instantiates the template with ``custom_kernel_myexp_float<float>``.
|
||||
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
|
||||
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
|
||||
These will be added as function arguments.
|
||||
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
|
||||
|
||||
Putting this all together, the generated function signature for ``myexp`` is as follows:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void custom_kernel_myexp_float(
|
||||
const device float16_t* inp [[buffer(0)]],
|
||||
device float16_t* out [[buffer(1)]],
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
|
||||
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
|
||||
}
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
# make non-contiguous
|
||||
a = a[::2]
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Complex Example
|
||||
-----------------------------
|
||||
|
||||
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
|
||||
|
||||
We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
|
||||
return output
|
||||
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
|
||||
On an M1 Max, we see a big performance improvement:
|
||||
|
||||
``55.7ms -> 6.7ms => 8x speed up``
|
||||
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
|
||||
* ``atomic_outputs=True``
|
||||
Designate all of the kernel outputs as ``atomic`` in the function signature.
|
||||
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
|
||||
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
|
||||
|
||||
We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
``676.4ms -> 16.7ms => 40x speed up``
|
@@ -1,5 +1,5 @@
|
||||
Developer Documentation
|
||||
=======================
|
||||
Custom Extensions in MLX
|
||||
========================
|
||||
|
||||
You can extend MLX with custom operations on the CPU or GPU. This guide
|
||||
explains how to do that with a simple example.
|
||||
@@ -486,15 +486,14 @@ below.
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available and look for it
|
||||
// in the same folder as this executable if needed
|
||||
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
@@ -503,11 +502,11 @@ below.
|
||||
size_t nelem = out.size();
|
||||
|
||||
// Encode input arrays to kernel
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, y, 1);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(y, 1);
|
||||
|
||||
// Encode output arrays to kernel
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
@@ -531,7 +530,7 @@ below.
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||
@@ -825,7 +824,7 @@ Let's look at a simple script and its results:
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correctness: {mx.all(c == 6.0).item()}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
|
||||
Output:
|
||||
|
||||
|
@@ -15,7 +15,7 @@ module to concisely define the model architecture.
|
||||
Attention layer
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
We will start with the llama attention layer which notably uses the RoPE
|
||||
We will start with the Llama attention layer which notably uses the RoPE
|
||||
positional encoding. [1]_ In addition, our attention layer will optionally use a
|
||||
key/value cache that will be concatenated with the provided keys and values to
|
||||
support efficient inference.
|
||||
|
@@ -64,7 +64,7 @@ set:
|
||||
Next, setup the problem parameters and load the data. To load the data, you need our
|
||||
`mnist data loader
|
||||
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
|
||||
we will import as `mnist`.
|
||||
we will import as ``mnist``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@@ -43,6 +43,7 @@ are the CPU and GPU.
|
||||
usage/function_transforms
|
||||
usage/compile
|
||||
usage/numpy
|
||||
usage/distributed
|
||||
usage/using_streams
|
||||
|
||||
.. toctree::
|
||||
@@ -69,6 +70,7 @@ are the CPU and GPU.
|
||||
python/metal
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/distributed
|
||||
python/tree_utils
|
||||
|
||||
.. toctree::
|
||||
@@ -83,3 +85,4 @@ are the CPU and GPU.
|
||||
|
||||
dev/extensions
|
||||
dev/metal_debugger
|
||||
dev/custom_metal_kernels
|
||||
|
@@ -14,7 +14,7 @@ silicon computer is
|
||||
To install from PyPI you must meet the following requirements:
|
||||
|
||||
- Using an M series chip (Apple silicon)
|
||||
- Using a native Python >= 3.8
|
||||
- Using a native Python >= 3.9
|
||||
- macOS >= 13.5
|
||||
|
||||
.. note::
|
||||
@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
|
||||
|
||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||
|
||||
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
|
||||
Then simply build and install MLX using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
||||
|
||||
For developing use an editable install:
|
||||
For developing, install the package with development dependencies, and use an
|
||||
editable install:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
||||
|
||||
To make sure the install is working run the tests with:
|
||||
Once the development dependencies are installed, you can build faster with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
||||
|
||||
Run the tests with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install ".[testing]"
|
||||
python -m unittest discover python/tests
|
||||
|
||||
Optional: Install stubs to enable auto completions and type checking from your IDE:
|
||||
Optional: Install stubs to enable auto completions and type checking from your
|
||||
IDE:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install ".[dev]"
|
||||
python setup.py generate_stubs
|
||||
|
||||
C++ API
|
||||
@@ -153,11 +153,18 @@ should point to the path to the built metal library.
|
||||
- OFF
|
||||
* - MLX_BUILD_METAL
|
||||
- ON
|
||||
* - MLX_BUILD_CPU
|
||||
- ON
|
||||
* - MLX_BUILD_PYTHON_BINDINGS
|
||||
- OFF
|
||||
* - MLX_METAL_DEBUG
|
||||
- OFF
|
||||
|
||||
* - MLX_BUILD_SAFETENSORS
|
||||
- ON
|
||||
* - MLX_BUILD_GGUF
|
||||
- ON
|
||||
* - MLX_METAL_JIT
|
||||
- OFF
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -176,10 +183,37 @@ should point to the path to the built metal library.
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
|
||||
and ``BUILD_SHARED_LIBS=ON``.
|
||||
|
||||
The MLX CMake build has several additional options to make smaller binaries.
|
||||
For example, if you don't need the CPU backend or support for safetensors and
|
||||
GGUF, you can do:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
cmake .. \
|
||||
-DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
|
||||
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
|
||||
contains pre-built GPU kernels. This substantially reduces the size of the
|
||||
Metal library by run-time compiling kernels the first time they are used in MLX
|
||||
on a given machine. Note run-time compilation incurs a cold-start cost which can
|
||||
be anwywhere from a few hundred millisecond to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists accross reboots.
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Metal not found
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -206,7 +240,7 @@ x86 Shell
|
||||
|
||||
.. _build shell:
|
||||
|
||||
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
|
||||
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
|
||||
Rosetta instead of natively.
|
||||
|
||||
To fix this, find the application in Finder (``/Applications`` for iTerm,
|
||||
@@ -230,4 +264,4 @@ Also check that cmake is using the correct architecture:
|
||||
|
||||
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
|
||||
but the build errors out with "Building for x86_64 on macOS is not supported."
|
||||
wipe your build cahce with ``rm -rf build/`` and try again.
|
||||
wipe your build cache with ``rm -rf build/`` and try again.
|
||||
|
@@ -24,6 +24,7 @@ Array
|
||||
array.any
|
||||
array.argmax
|
||||
array.argmin
|
||||
array.conj
|
||||
array.cos
|
||||
array.cummax
|
||||
array.cummin
|
||||
@@ -52,8 +53,10 @@ Array
|
||||
array.sqrt
|
||||
array.square
|
||||
array.squeeze
|
||||
array.swapaxes
|
||||
array.std
|
||||
array.sum
|
||||
array.swapaxes
|
||||
array.transpose
|
||||
array.T
|
||||
array.var
|
||||
array.view
|
||||
|
22
docs/src/python/distributed.rst
Normal file
22
docs/src/python/distributed.rst
Normal file
@@ -0,0 +1,22 @@
|
||||
.. _distributed:
|
||||
|
||||
.. currentmodule:: mlx.core.distributed
|
||||
|
||||
Distributed Communication
|
||||
==========================
|
||||
|
||||
MLX provides a distributed communication package using MPI. The MPI library is
|
||||
loaded at runtime; if MPI is available then distributed communication is also
|
||||
made available.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Group
|
||||
is_available
|
||||
init
|
||||
all_sum
|
||||
all_gather
|
||||
send
|
||||
recv
|
||||
recv_like
|
@@ -12,3 +12,5 @@ Fast
|
||||
layer_norm
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
affine_quantize
|
||||
metal_kernel
|
||||
|
@@ -8,5 +8,13 @@ Linear Algebra
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
inv
|
||||
tri_inv
|
||||
norm
|
||||
cholesky
|
||||
cholesky_inv
|
||||
cross
|
||||
qr
|
||||
svd
|
||||
eigvalsh
|
||||
eigh
|
||||
|
@@ -10,9 +10,11 @@ Metal
|
||||
device_info
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@@ -13,10 +13,13 @@ simple functions.
|
||||
:template: nn-module-template.rst
|
||||
|
||||
elu
|
||||
celu
|
||||
gelu
|
||||
gelu_approx
|
||||
gelu_fast_approx
|
||||
glu
|
||||
hard_shrink
|
||||
hard_tanh
|
||||
hardswish
|
||||
leaky_relu
|
||||
log_sigmoid
|
||||
@@ -29,6 +32,7 @@ simple functions.
|
||||
sigmoid
|
||||
silu
|
||||
softmax
|
||||
softmin
|
||||
softplus
|
||||
softshrink
|
||||
step
|
||||
|
@@ -13,18 +13,31 @@ Layers
|
||||
AvgPool1d
|
||||
AvgPool2d
|
||||
BatchNorm
|
||||
CELU
|
||||
Conv1d
|
||||
Conv2d
|
||||
Conv3d
|
||||
ConvTranspose1d
|
||||
ConvTranspose2d
|
||||
ConvTranspose3d
|
||||
Dropout
|
||||
Dropout2d
|
||||
Dropout3d
|
||||
Embedding
|
||||
ELU
|
||||
GELU
|
||||
GLU
|
||||
GroupNorm
|
||||
GRU
|
||||
HardShrink
|
||||
HardTanh
|
||||
Hardswish
|
||||
InstanceNorm
|
||||
LayerNorm
|
||||
LeakyReLU
|
||||
Linear
|
||||
LogSigmoid
|
||||
LogSoftmax
|
||||
LSTM
|
||||
MaxPool1d
|
||||
MaxPool2d
|
||||
@@ -35,13 +48,20 @@ Layers
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
ReLU6
|
||||
RNN
|
||||
RoPE
|
||||
SELU
|
||||
Sequential
|
||||
Sigmoid
|
||||
SiLU
|
||||
SinusoidalPositionalEncoding
|
||||
Softmin
|
||||
Softshrink
|
||||
Softsign
|
||||
Softmax
|
||||
Softplus
|
||||
Step
|
||||
Tanh
|
||||
Transformer
|
||||
Upsample
|
||||
|
@@ -10,6 +10,7 @@ Operations
|
||||
|
||||
abs
|
||||
add
|
||||
addmm
|
||||
all
|
||||
allclose
|
||||
any
|
||||
@@ -19,12 +20,14 @@ Operations
|
||||
arcsin
|
||||
arcsinh
|
||||
arctan
|
||||
arctan2
|
||||
arctanh
|
||||
argmax
|
||||
argmin
|
||||
argpartition
|
||||
argsort
|
||||
array_equal
|
||||
as_strided
|
||||
atleast_1d
|
||||
atleast_2d
|
||||
atleast_3d
|
||||
@@ -32,14 +35,19 @@ Operations
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
block_sparse_mm
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
conj
|
||||
conjugate
|
||||
convolve
|
||||
conv1d
|
||||
conv2d
|
||||
conv3d
|
||||
conv_transpose1d
|
||||
conv_transpose2d
|
||||
conv_transpose3d
|
||||
conv_general
|
||||
cos
|
||||
cosh
|
||||
@@ -53,6 +61,8 @@ Operations
|
||||
diagonal
|
||||
divide
|
||||
divmod
|
||||
einsum
|
||||
einsum_path
|
||||
equal
|
||||
erf
|
||||
erfinv
|
||||
@@ -64,15 +74,21 @@ Operations
|
||||
floor
|
||||
floor_divide
|
||||
full
|
||||
gather_mm
|
||||
gather_qmm
|
||||
greater
|
||||
greater_equal
|
||||
hadamard_transform
|
||||
identity
|
||||
imag
|
||||
inner
|
||||
isfinite
|
||||
isclose
|
||||
isinf
|
||||
isnan
|
||||
isneginf
|
||||
isposinf
|
||||
issubdtype
|
||||
left_shift
|
||||
less
|
||||
less_equal
|
||||
@@ -96,6 +112,7 @@ Operations
|
||||
minimum
|
||||
moveaxis
|
||||
multiply
|
||||
nan_to_num
|
||||
negative
|
||||
not_equal
|
||||
ones
|
||||
@@ -103,14 +120,19 @@ Operations
|
||||
outer
|
||||
partition
|
||||
pad
|
||||
power
|
||||
prod
|
||||
put_along_axis
|
||||
quantize
|
||||
quantized_matmul
|
||||
radians
|
||||
real
|
||||
reciprocal
|
||||
remainder
|
||||
repeat
|
||||
reshape
|
||||
right_shift
|
||||
roll
|
||||
round
|
||||
rsqrt
|
||||
save
|
||||
@@ -141,11 +163,13 @@ Operations
|
||||
tensordot
|
||||
tile
|
||||
topk
|
||||
trace
|
||||
transpose
|
||||
tri
|
||||
tril
|
||||
triu
|
||||
var
|
||||
view
|
||||
where
|
||||
zeros
|
||||
zeros_like
|
||||
|
@@ -1,5 +1,7 @@
|
||||
.. _optimizers:
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
Optimizers
|
||||
==========
|
||||
|
||||
@@ -29,8 +31,48 @@ model's parameters and the **optimizer state**.
|
||||
# Compute the new parameters but also the optimizer state.
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
Saving and Loading
|
||||
------------------
|
||||
|
||||
To serialize an optimizer, save its state. To load an optimizer, load and set
|
||||
the saved state. Here's a simple example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
import mlx.optimizers as optim
|
||||
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
# Perform some updates with the optimizer
|
||||
model = {"w" : mx.zeros((5, 5))}
|
||||
grads = {"w" : mx.ones((5, 5))}
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Save the state
|
||||
state = tree_flatten(optimizer.state)
|
||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
||||
|
||||
# Later on, for example when loading from a checkpoint,
|
||||
# recreate the optimizer and load the state
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
||||
optimizer.state = state
|
||||
|
||||
Note, not every optimizer configuation parameter is saved in the state. For
|
||||
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
|
||||
parameters are not. A good rule of thumb is if the parameter can be scheduled
|
||||
then it will be included in the optimizer state.
|
||||
|
||||
.. toctree::
|
||||
|
||||
optimizers/optimizer
|
||||
optimizers/common_optimizers
|
||||
optimizers/schedulers
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
clip_grad_norm
|
||||
|
@@ -44,3 +44,5 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
|
||||
split
|
||||
truncated_normal
|
||||
uniform
|
||||
laplace
|
||||
permutation
|
||||
|
@@ -10,6 +10,7 @@ Transforms
|
||||
|
||||
eval
|
||||
compile
|
||||
custom_function
|
||||
disable_compile
|
||||
enable_compile
|
||||
grad
|
||||
|
@@ -20,3 +20,4 @@ return python trees will be using the default python ``dict``, ``list`` and
|
||||
tree_unflatten
|
||||
tree_map
|
||||
tree_map_with_path
|
||||
tree_reduce
|
||||
|
@@ -33,12 +33,12 @@ Let's start with a simple example:
|
||||
# Compile the function
|
||||
compiled_fun = mx.compile(fun)
|
||||
|
||||
# Prints: array(2.36788, dtype=float32)
|
||||
# Prints: array(2.36788, dtype=float32)
|
||||
print(compiled_fun(x, y))
|
||||
|
||||
The output of both the regular function and the compiled function is the same
|
||||
up to numerical precision.
|
||||
|
||||
|
||||
The first time you call a compiled function, MLX will build the compute
|
||||
graph, optimize it, and generate and compile code. This can be relatively
|
||||
slow. However, MLX will cache compiled functions, so calling a compiled
|
||||
@@ -96,7 +96,7 @@ element-wise operations:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def gelu(x):
|
||||
def gelu(x):
|
||||
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||
|
||||
If you use this function with small arrays, it will be overhead bound. If you
|
||||
@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
|
||||
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||
five times faster.
|
||||
|
||||
.. note::
|
||||
|
||||
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
|
||||
functions can still be helpful, but won't typically result in as large a
|
||||
speedup as compiling operations that run on the GPU.
|
||||
|
||||
|
||||
Debugging
|
||||
---------
|
||||
|
||||
@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
|
||||
print(fun(mx.array(1.0)))
|
||||
|
||||
|
||||
Compiling Training Graphs
|
||||
Compiling Training Graphs
|
||||
-------------------------
|
||||
|
||||
This section will step through how to use :func:`compile` with a simple example
|
||||
@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
|
||||
|
||||
To start, here is the simple example without any compilation:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
|
||||
To compile the update we can put it all in a function and compile it with the
|
||||
appropriate input and output captures. Here's the same example but compiled:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
|
||||
|
||||
# The state that will be captured as input and output
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(x, y):
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
|
||||
|
||||
In order to compile as much as possible, a transformation of a compiled
|
||||
function will not by default be compiled. To compile the transformed
|
||||
function simply pass it through :func:`compile`.
|
||||
function simply pass it through :func:`compile`.
|
||||
|
||||
You can also compile functions which themselves call compiled functions. A
|
||||
good practice is to compile the outer most function to give :func:`compile`
|
||||
|
166
docs/src/usage/distributed.rst
Normal file
166
docs/src/usage/distributed.rst
Normal file
@@ -0,0 +1,166 @@
|
||||
.. _usage_distributed:
|
||||
|
||||
Distributed Communication
|
||||
=========================
|
||||
|
||||
.. currentmodule:: mlx.core.distributed
|
||||
|
||||
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
|
||||
provide distributed communication operations that allow the computational cost
|
||||
of training or inference to be shared across many physical machines. You can
|
||||
see a list of the supported operations in the :ref:`API docs<distributed>`.
|
||||
|
||||
.. note::
|
||||
A lot of operations may not be supported or not as fast as they should be.
|
||||
We are adding more and tuning the ones we have as we are figuring out the
|
||||
best way to do distributed computing on Macs using MLX.
|
||||
|
||||
Getting Started
|
||||
---------------
|
||||
|
||||
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
||||
machine. The minimal distributed program in MLX is as simple as:
|
||||
|
||||
.. code:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
world = mx.distributed.init()
|
||||
x = mx.distributed.all_sum(mx.ones(10))
|
||||
print(world.rank(), x)
|
||||
|
||||
The program above sums the array ``mx.ones(10)`` across all
|
||||
distributed processes. If simply run with ``python``, however, only one
|
||||
process is launched and no distributed communication takes place.
|
||||
|
||||
To launch the program in distributed mode we need to use ``mpirun`` or
|
||||
``mpiexec`` depending on the MPI installation. The simplest possible way is the
|
||||
following:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mpirun -np 2 python test.py
|
||||
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||
|
||||
The above launches two processes on the same (local) machine and we can see
|
||||
both standard output streams. The processes send the array of 1s to each other
|
||||
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
|
||||
print 4 etc.
|
||||
|
||||
Installing MPI
|
||||
---------------
|
||||
|
||||
MPI can be installed with Homebrew, using the Anaconda package manager or
|
||||
compiled from source. Most of our testing is done using ``openmpi`` installed
|
||||
with the Anaconda package manager as follows:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ conda install openmpi
|
||||
|
||||
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
||||
so that MLX can find it and load it at runtime. This can simply be achieved by
|
||||
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
||||
|
||||
Setting up Remote Hosts
|
||||
-----------------------
|
||||
|
||||
MPI can automatically connect to remote hosts and set up the communication over
|
||||
the network if the remote hosts can be accessed via ssh. A good checklist to
|
||||
debug connectivity issues is the following:
|
||||
|
||||
* ``ssh hostname`` works from all machines to all machines without asking for
|
||||
password or host confirmation
|
||||
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
|
||||
full path to force all machines to use a specific path.
|
||||
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
||||
in the ``.ssh/config`` files on all machines.
|
||||
|
||||
.. note::
|
||||
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
|
||||
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
|
||||
|
||||
An easy way to pass the host names to MPI is using a host file. A host file
|
||||
looks like the following, where ``host1`` and ``host2`` should be the fully
|
||||
qualified domain names or IPs for these hosts.
|
||||
|
||||
.. code::
|
||||
|
||||
host1 slots=1
|
||||
host2 slots=1
|
||||
|
||||
When using MLX, it is very likely that you want to use 1 slot per host, ie one
|
||||
process per host. The hostfile also needs to contain the current
|
||||
host if you want to run on the local host. Passing the host file to
|
||||
``mpirun`` is simply done using the ``--hostfile`` command line argument.
|
||||
|
||||
Training Example
|
||||
----------------
|
||||
|
||||
In this section we will adapt an MLX training loop to support data parallel
|
||||
distributed training. Namely, we will average the gradients across a set of
|
||||
hosts before applying them to the model.
|
||||
|
||||
Our training loop looks like the following code snippet if we omit the model,
|
||||
dataset and optimizer initialization.
|
||||
|
||||
.. code:: python
|
||||
|
||||
model = ...
|
||||
optimizer = ...
|
||||
dataset = ...
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
for x, y in dataset:
|
||||
loss = step(model, x, y)
|
||||
mx.eval(loss, model.parameters())
|
||||
|
||||
All we have to do to average the gradients across machines is perform an
|
||||
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
|
||||
have to :func:`mlx.utils.tree_map` the gradients with following function.
|
||||
|
||||
.. code:: python
|
||||
|
||||
def all_avg(x):
|
||||
return mx.distributed.all_sum(x) / mx.distributed.init().size()
|
||||
|
||||
Putting everything together our training loop step looks as follows with
|
||||
everything else remaining the same.
|
||||
|
||||
.. code:: python
|
||||
|
||||
from mlx.utils import tree_map
|
||||
|
||||
def all_reduce_grads(grads):
|
||||
N = mx.distributed.init()
|
||||
if N == 1:
|
||||
return grads
|
||||
return tree_map(
|
||||
lambda x: mx.distributed.all_sum(x) / N,
|
||||
grads)
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
grads = all_reduce_grads(grads) # <--- This line was added
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
Tuning All Reduce
|
||||
-----------------
|
||||
|
||||
We are working on improving the performance of all reduce on MLX but for now
|
||||
the two main things one can do to extract the most out of distributed training with MLX are:
|
||||
|
||||
1. Perform a few large reductions instead of many small ones to improve
|
||||
bandwidth and latency
|
||||
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
|
||||
connections between each host to improve bandwidth
|
@@ -25,7 +25,7 @@ Here is a simple example:
|
||||
|
||||
The output of :func:`grad` on :func:`sin` is simply another function. In this
|
||||
case it is the gradient of the sine function which is exactly the cosine
|
||||
function. To get the second derivative you can do:
|
||||
function. To get the second derivative you can do:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
@@ -50,7 +50,7 @@ Automatic Differentiation
|
||||
.. _auto diff:
|
||||
|
||||
Automatic differentiation in MLX works on functions rather than on implicit
|
||||
graphs.
|
||||
graphs.
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -114,7 +114,7 @@ way to do that is the following:
|
||||
|
||||
def loss_fn(params, x, y):
|
||||
w, b = params["weight"], params["bias"]
|
||||
h = w * x + b
|
||||
h = w * x + b
|
||||
return mx.mean(mx.square(h - y))
|
||||
|
||||
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
|
||||
@@ -132,7 +132,7 @@ way to do that is the following:
|
||||
|
||||
Notice the tree structure of the parameters is preserved in the gradients.
|
||||
|
||||
In some cases you may want to stop gradients from propagating through a
|
||||
In some cases you may want to stop gradients from propagating through a
|
||||
part of the function. You can use the :func:`stop_gradient` for that.
|
||||
|
||||
|
||||
@@ -166,14 +166,14 @@ A naive way to add the elements from two sets of vectors is with a loop:
|
||||
Instead you can use :func:`vmap` to automatically vectorize the addition:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
# Vectorize over the second dimension of x and the
|
||||
# first dimension of y
|
||||
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
|
||||
|
||||
The ``in_axes`` parameter can be used to specify which dimensions of the
|
||||
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
|
||||
where the vectorized axes should be in the outputs.
|
||||
where the vectorized axes should be in the outputs.
|
||||
|
||||
Let's time these two different versions:
|
||||
|
||||
|
@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
|
||||
.. code-block:: shell
|
||||
|
||||
>>> arr = mx.arange(10)
|
||||
>>> idx = mx.array([5, 7])
|
||||
>>> idx = mx.array([5, 7])
|
||||
>>> arr[idx]
|
||||
array([5, 7], dtype=int32)
|
||||
|
||||
@@ -82,7 +82,7 @@ general, MLX has limited support for operations for which outputs
|
||||
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
||||
single input version of :func:`numpy.where`.
|
||||
|
||||
In Place Updates
|
||||
In Place Updates
|
||||
----------------
|
||||
|
||||
In place updates to indexed arrays are possible in MLX. For example:
|
||||
|
@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
|
||||
:func:`eval` is performed.
|
||||
|
||||
MLX uses lazy evaluation because it has some nice features, some of which we
|
||||
describe below.
|
||||
describe below.
|
||||
|
||||
Transforming Compute Graphs
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
@@ -116,7 +116,7 @@ saving functions) will also evaluate the array.
|
||||
|
||||
Calling :func:`array.item` on a scalar array will also evaluate it. In the
|
||||
example above, printing the loss (``print(loss)``) or adding the loss scalar to
|
||||
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
|
||||
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
|
||||
these lines are before ``mx.eval(loss, model.parameters())`` then this
|
||||
will be a partial evaluation, computing only the forward pass.
|
||||
|
||||
|
@@ -3,7 +3,11 @@
|
||||
Conversion to NumPy and Other Frameworks
|
||||
========================================
|
||||
|
||||
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
||||
MLX array supports conversion between other frameworks with either:
|
||||
|
||||
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
||||
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
|
||||
|
||||
Let's convert an array to NumPy and back.
|
||||
|
||||
.. code-block:: python
|
||||
@@ -62,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
|
||||
PyTorch
|
||||
-------
|
||||
|
||||
.. warning::
|
||||
.. warning::
|
||||
|
||||
PyTorch Support for :obj:`memoryview` is experimental and can break for
|
||||
multi-dimensional arrays. Casting to NumPy first is advised for now.
|
||||
|
@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
|
||||
and :func:`jvp` for Jacobian-vector products.
|
||||
|
||||
Use :func:`value_and_grad` to efficiently compute both a function's output and
|
||||
gradient with respect to the function's input.
|
||||
gradient with respect to the function's input.
|
||||
|
@@ -8,33 +8,33 @@ Saving and Loading Arrays
|
||||
MLX supports multiple array serialization formats.
|
||||
|
||||
.. list-table:: Serialization Formats
|
||||
:widths: 20 8 25 25
|
||||
:widths: 20 8 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* - Format
|
||||
- Extension
|
||||
* - Format
|
||||
- Extension
|
||||
- Function
|
||||
- Notes
|
||||
* - NumPy
|
||||
- ``.npy``
|
||||
- Notes
|
||||
* - NumPy
|
||||
- ``.npy``
|
||||
- :func:`save`
|
||||
- Single arrays only
|
||||
* - NumPy archive
|
||||
- ``.npz``
|
||||
* - NumPy archive
|
||||
- ``.npz``
|
||||
- :func:`savez` and :func:`savez_compressed`
|
||||
- Multiple arrays
|
||||
- Multiple arrays
|
||||
* - Safetensors
|
||||
- ``.safetensors``
|
||||
- ``.safetensors``
|
||||
- :func:`save_safetensors`
|
||||
- Multiple arrays
|
||||
* - GGUF
|
||||
- ``.gguf``
|
||||
- Multiple arrays
|
||||
* - GGUF
|
||||
- ``.gguf``
|
||||
- :func:`save_gguf`
|
||||
- Multiple arrays
|
||||
|
||||
The :func:`load` function will load any of the supported serialization
|
||||
formats. It determines the format from the extensions. The output of
|
||||
:func:`load` depends on the format.
|
||||
:func:`load` depends on the format.
|
||||
|
||||
Here's an example of saving a single array to a file:
|
||||
|
||||
|
@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
|
||||
|
||||
In MLX, rather than moving arrays to devices, you specify the device when you
|
||||
run the operation. Any device can perform any operation on ``a`` and ``b``
|
||||
without needing to move them from one memory location to another. For example:
|
||||
without needing to move them from one memory location to another. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
|
||||
build_example(linear_regression.cpp)
|
||||
build_example(logistic_regression.cpp)
|
||||
build_example(metal_capture.cpp)
|
||||
build_example(distributed.cpp)
|
||||
|
22
examples/cpp/distributed.cpp
Normal file
22
examples/cpp/distributed.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
if (!distributed::is_available()) {
|
||||
std::cout << "No communication backend found" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto global_group = distributed::init();
|
||||
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
|
||||
|
||||
array x = ones({10});
|
||||
array out = distributed::all_sum(x, global_group);
|
||||
|
||||
std::cout << out << std::endl;
|
||||
}
|
@@ -89,8 +89,8 @@ void automatic_differentiation() {
|
||||
// dfdx is 2 * x
|
||||
|
||||
// Get the second derivative by composing grad with grad
|
||||
auto df2dx2 = grad(grad(fn))(x);
|
||||
// df2dx2 is 2
|
||||
auto d2fdx2 = grad(grad(fn))(x);
|
||||
// d2fdx2 is 2
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
@@ -11,10 +11,14 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
find_package(
|
||||
Python 3.8
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
@@ -24,16 +28,10 @@ find_package(nanobind CONFIG REQUIRED)
|
||||
add_library(mlx_ext)
|
||||
|
||||
# Add sources
|
||||
target_sources(
|
||||
mlx_ext
|
||||
PUBLIC
|
||||
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
|
||||
)
|
||||
target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
|
||||
|
||||
# Add include headers
|
||||
target_include_directories(
|
||||
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Link to mlx
|
||||
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
@@ -43,27 +41,32 @@ target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
# Build metallib
|
||||
if(MLX_BUILD_METAL)
|
||||
mlx_build_metallib(
|
||||
TARGET mlx_ext_metallib
|
||||
TITLE mlx_ext
|
||||
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
|
||||
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
|
||||
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
|
||||
)
|
||||
|
||||
add_dependencies(
|
||||
mlx_ext
|
||||
TARGET
|
||||
mlx_ext_metallib
|
||||
)
|
||||
TITLE
|
||||
mlx_ext
|
||||
SOURCES
|
||||
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
|
||||
INCLUDE_DIRS
|
||||
${PROJECT_SOURCE_DIR}
|
||||
${MLX_INCLUDE_DIRS}
|
||||
OUTPUT_DIRECTORY
|
||||
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||
|
||||
add_dependencies(mlx_ext mlx_ext_metallib)
|
||||
|
||||
endif()
|
||||
|
||||
# ----------------------------- Python Bindings -----------------------------
|
||||
nanobind_add_module(
|
||||
_ext
|
||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
||||
NB_DOMAIN mlx
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
NB_STATIC
|
||||
STABLE_ABI
|
||||
LTO
|
||||
NOMINSIZE
|
||||
NB_DOMAIN
|
||||
mlx
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
|
||||
target_link_libraries(_ext PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
|
||||
## Build the extensions
|
||||
## Build
|
||||
|
||||
```
|
||||
pip install -e .
|
||||
@@ -16,3 +16,9 @@ And then run:
|
||||
```
|
||||
python setup.py build_ext -j8 --inplace
|
||||
```
|
||||
|
||||
## Test
|
||||
|
||||
```
|
||||
python test.py
|
||||
```
|
||||
|
@@ -249,15 +249,14 @@ void Axpby::eval_gpu(
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available and look for it
|
||||
// in the same folder as this executable if needed
|
||||
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
@@ -266,11 +265,11 @@ void Axpby::eval_gpu(
|
||||
size_t nelem = out.size();
|
||||
|
||||
// Encode input arrays to kernel
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, y, 1);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(y, 1);
|
||||
|
||||
// Encode output arrays to kernel
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
@@ -296,7 +295,7 @@ void Axpby::eval_gpu(
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#else // Metal is not available
|
||||
|
@@ -2,4 +2,4 @@
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .mlx_sample_extensions import *
|
||||
from ._ext import axpby
|
||||
|
@@ -2,7 +2,7 @@
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.24",
|
||||
"mlx>=0.9.0",
|
||||
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
|
||||
"mlx>=0.18.0",
|
||||
"nanobind==2.2.0",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.24
|
||||
mlx>=0.9.0
|
||||
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
mlx>=0.18.1
|
||||
nanobind==2.2.0
|
||||
|
@@ -13,7 +13,6 @@ if __name__ == "__main__":
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
extras_require={"dev": []},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
10
examples/extensions/test.py
Normal file
10
examples/extensions/test.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import mlx.core as mx
|
||||
from mlx_sample_extensions import axpby
|
||||
|
||||
a = mx.ones((3, 4))
|
||||
b = mx.ones((3, 4))
|
||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
@@ -1,37 +1,40 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if (MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
if(MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
|
||||
)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
elseif(MLX_BUILD_CPU)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
|
@@ -23,11 +23,22 @@ void free(Buffer buffer) {
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
return Buffer{std::malloc(size)};
|
||||
void* ptr = std::malloc(size + sizeof(size_t));
|
||||
if (ptr != nullptr) {
|
||||
*static_cast<size_t*>(ptr) = size;
|
||||
}
|
||||
return Buffer{ptr};
|
||||
}
|
||||
|
||||
void CommonAllocator::free(Buffer buffer) {
|
||||
std::free(buffer.raw_ptr());
|
||||
std::free(buffer.ptr());
|
||||
}
|
||||
|
||||
size_t CommonAllocator::size(Buffer buffer) const {
|
||||
if (buffer.ptr() == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return *static_cast<size_t*>(buffer.ptr());
|
||||
}
|
||||
|
||||
Buffer malloc_or_wait(size_t size) {
|
||||
|
@@ -41,6 +41,7 @@ class Allocator {
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
|
||||
Allocator() = default;
|
||||
Allocator(const Allocator& other) = delete;
|
||||
@@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
private:
|
||||
CommonAllocator() = default;
|
||||
|
@@ -17,6 +17,10 @@ bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@@ -91,18 +95,34 @@ void array::detach() {
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
void array::eval() {
|
||||
// Ensure the array is ready to be read
|
||||
if (status() == Status::scheduled) {
|
||||
bool array::is_available() const {
|
||||
if (status() == Status::available) {
|
||||
return true;
|
||||
} else if (status() == Status::evaluated && event().is_signaled()) {
|
||||
set_status(Status::available);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void array::wait() {
|
||||
if (!is_available()) {
|
||||
event().wait();
|
||||
set_status(Status::available);
|
||||
} else if (status() == Status::unscheduled) {
|
||||
}
|
||||
}
|
||||
|
||||
void array::eval() {
|
||||
// Ensure the array is ready to be read
|
||||
if (status() == Status::unscheduled) {
|
||||
mlx::core::eval({*this});
|
||||
} else {
|
||||
wait();
|
||||
}
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
@@ -158,8 +178,10 @@ void array::move_shared_buffer(
|
||||
array_desc_->flags = flags;
|
||||
array_desc_->data_size = data_size;
|
||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||
array_desc_->data_ptr = static_cast<void*>(
|
||||
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||
auto data_ptr = other.array_desc_->data_ptr;
|
||||
other.array_desc_->data_ptr = nullptr;
|
||||
array_desc_->data_ptr =
|
||||
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(array other) {
|
||||
@@ -171,10 +193,11 @@ array::~array() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore arrays that will be detached
|
||||
if (status() != array::Status::unscheduled) {
|
||||
// Ignore arrays that might be detached during eval
|
||||
if (status() == array::Status::scheduled) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Break circular reference for non-detached arrays with siblings
|
||||
if (auto n = siblings().size(); n > 0) {
|
||||
bool do_detach = true;
|
||||
@@ -206,7 +229,7 @@ void array::ArrayDesc::init() {
|
||||
strides[i] = size;
|
||||
size *= shape[i];
|
||||
}
|
||||
for (auto& in : inputs) {
|
||||
for (const auto& in : inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
}
|
||||
}
|
||||
@@ -231,31 +254,41 @@ array::ArrayDesc::ArrayDesc(
|
||||
|
||||
array::ArrayDesc::~ArrayDesc() {
|
||||
// When an array description is destroyed it will delete a bunch of arrays
|
||||
// that may also destory their corresponding descriptions and so on and so
|
||||
// that may also destroy their corresponding descriptions and so on and so
|
||||
// forth.
|
||||
//
|
||||
// This calls recursively the destructor and can result in stack overflow, we
|
||||
// instead put them in a vector and destroy them one at a time resulting in a
|
||||
// max stack depth of 2.
|
||||
if (inputs.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
|
||||
|
||||
for (array& a : inputs) {
|
||||
if (a.array_desc_.use_count() == 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
|
||||
std::unordered_map<std::uintptr_t, array> input_map;
|
||||
for (array& a : ad.inputs) {
|
||||
if (a.array_desc_) {
|
||||
input_map.insert({a.id(), a});
|
||||
}
|
||||
}
|
||||
}
|
||||
ad.inputs.clear();
|
||||
for (auto& [_, a] : input_map) {
|
||||
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
append_deletable_inputs(*this);
|
||||
|
||||
while (!for_deletion.empty()) {
|
||||
// top is going to be deleted at the end of the block *after* the arrays
|
||||
// with inputs have been moved into the vector
|
||||
auto top = std::move(for_deletion.back());
|
||||
for_deletion.pop_back();
|
||||
|
||||
for (array& a : top->inputs) {
|
||||
if (a.array_desc_.use_count() == 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
append_deletable_inputs(*top);
|
||||
}
|
||||
}
|
||||
|
||||
|
114
mlx/array.h
114
mlx/array.h
@@ -73,32 +73,32 @@ class array {
|
||||
this->array_desc_ = other.array_desc_;
|
||||
}
|
||||
return *this;
|
||||
};
|
||||
}
|
||||
|
||||
/** The size of the array's datatype in bytes. */
|
||||
size_t itemsize() const {
|
||||
return size_of(dtype());
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of elements in the array. */
|
||||
size_t size() const {
|
||||
return array_desc_->size;
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of bytes in the array. */
|
||||
size_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of dimensions of the array. */
|
||||
size_t ndim() const {
|
||||
return array_desc_->shape.size();
|
||||
};
|
||||
}
|
||||
|
||||
/** The shape of the array as a vector of integers. */
|
||||
const std::vector<int>& shape() const {
|
||||
return array_desc_->shape;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the size of the corresponding dimension.
|
||||
@@ -107,12 +107,12 @@ class array {
|
||||
* bounds checking. */
|
||||
int shape(int dim) const {
|
||||
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
}
|
||||
|
||||
/** The strides of the array. */
|
||||
const std::vector<size_t>& strides() const {
|
||||
return array_desc_->strides;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the stride of the corresponding dimension.
|
||||
@@ -121,12 +121,12 @@ class array {
|
||||
* bounds checking. */
|
||||
size_t strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
}
|
||||
|
||||
/** Get the arrays data type. */
|
||||
Dtype dtype() const {
|
||||
return array_desc_->dtype;
|
||||
};
|
||||
}
|
||||
|
||||
/** Evaluate the array. */
|
||||
void eval();
|
||||
@@ -160,10 +160,10 @@ class array {
|
||||
|
||||
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return a.arr.id() == b.arr.id() && a.idx == b.idx;
|
||||
};
|
||||
}
|
||||
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return !(a == b);
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
const array& arr;
|
||||
@@ -209,7 +209,7 @@ class array {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
: buffer(buffer), d(d) {};
|
||||
: buffer(buffer), d(d) {}
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
@@ -219,33 +219,45 @@ class array {
|
||||
};
|
||||
|
||||
struct Flags {
|
||||
// True if there are no gaps in the underlying data. Each item
|
||||
// True iff there are no gaps in the underlying data. Each item
|
||||
// in the underlying data buffer belongs to at least one index.
|
||||
//
|
||||
// True iff:
|
||||
// prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
|
||||
bool contiguous : 1;
|
||||
|
||||
// True iff:
|
||||
// strides[-1] == 1 and
|
||||
// all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
|
||||
// range(ndim - 1))
|
||||
bool row_contiguous : 1;
|
||||
|
||||
// True iff:
|
||||
// strides[0] == 1 and
|
||||
// all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
|
||||
// range(1, ndim))
|
||||
bool col_contiguous : 1;
|
||||
};
|
||||
|
||||
/** The array's primitive. */
|
||||
Primitive& primitive() const {
|
||||
return *(array_desc_->primitive);
|
||||
};
|
||||
}
|
||||
|
||||
/** A shared pointer to the array's primitive. */
|
||||
std::shared_ptr<Primitive>& primitive_ptr() const {
|
||||
return array_desc_->primitive;
|
||||
};
|
||||
}
|
||||
|
||||
/** Check if the array has an attached primitive or is a leaf node. */
|
||||
bool has_primitive() const {
|
||||
return array_desc_->primitive != nullptr;
|
||||
};
|
||||
}
|
||||
|
||||
/** The array's inputs. */
|
||||
const std::vector<array>& inputs() const {
|
||||
return array_desc_->inputs;
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array>& inputs() {
|
||||
return array_desc_->inputs;
|
||||
@@ -259,12 +271,12 @@ class array {
|
||||
/** The array's siblings. */
|
||||
const std::vector<array>& siblings() const {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
std::vector<array>& siblings() {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
}
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
@@ -281,7 +293,7 @@ class array {
|
||||
outputs.push_back(*this);
|
||||
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
|
||||
return outputs;
|
||||
};
|
||||
}
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
@@ -289,19 +301,32 @@ class array {
|
||||
/** Get the Flags bit-field. */
|
||||
const Flags& flags() const {
|
||||
return array_desc_->flags;
|
||||
};
|
||||
}
|
||||
|
||||
/** The size (in elements) of the underlying buffer the array points to. */
|
||||
/** The size (in elements) of the underlying buffer the array points to.
|
||||
*
|
||||
* This can be different than the actual size of the array if the array has
|
||||
* been broadcast or irregularly strided. If ``first`` is the offset into
|
||||
* the data buffer of the first element of the array (i.e. the offset
|
||||
* corresponding to ``arr[0, 0, ...]``) and last is the offset into the
|
||||
* data buffer of the last element of the array (i.e. the offset
|
||||
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
|
||||
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
|
||||
**/
|
||||
size_t data_size() const {
|
||||
return array_desc_->data_size;
|
||||
};
|
||||
}
|
||||
|
||||
allocator::Buffer& buffer() {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
}
|
||||
const allocator::Buffer& buffer() const {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
}
|
||||
|
||||
size_t buffer_size() const {
|
||||
return allocator::allocator().size(buffer());
|
||||
}
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
@@ -312,19 +337,42 @@ class array {
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
}
|
||||
|
||||
enum Status {
|
||||
// The ouptut of a computation which has not been scheduled.
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
// The ouptut of a computation which has been scheduled but `eval_*` has
|
||||
// not yet been called on the array's primitive. A possible
|
||||
// status of `x` in `auto x = a + b; eval(x);`
|
||||
scheduled,
|
||||
|
||||
// The array's `eval_*` function has been run, but the computation is not
|
||||
// necessarily complete. The array will have memory allocated and if it is
|
||||
// not a tracer then it will be detached from the graph.
|
||||
evaluated,
|
||||
|
||||
// If the array is the output of a computation then the computation
|
||||
// is complete. Constant arrays are always available (e.g. `array({1, 2,
|
||||
// 3})`)
|
||||
available
|
||||
};
|
||||
|
||||
enum Status { unscheduled, scheduled, available };
|
||||
// Check if the array is safe to read.
|
||||
bool is_available() const;
|
||||
|
||||
bool is_available() const {
|
||||
return status() == Status::available;
|
||||
}
|
||||
const Status status() const {
|
||||
// Wait on the array to be available. After this `is_available` returns
|
||||
// `true`.
|
||||
void wait();
|
||||
|
||||
Status status() const {
|
||||
return array_desc_->status;
|
||||
}
|
||||
|
||||
@@ -411,8 +459,6 @@ class array {
|
||||
void* data_ptr{nullptr};
|
||||
|
||||
// The size in elements of the data buffer the array accesses
|
||||
// This can be different than the actual size of the array if it
|
||||
// has been broadcast or irregularly strided.
|
||||
size_t data_size;
|
||||
|
||||
// Contains useful meta data about the array
|
||||
|
@@ -1,10 +1,8 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
|
||||
|
@@ -1,9 +1,9 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
#include <vecLib/vDSP.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
@@ -2,8 +2,7 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <vecLib/BNNS/bnns.h>
|
||||
#include <vecLib/cblas_new.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/backend/accelerate/utils.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
|
@@ -3,8 +3,7 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <vecLib/vDSP.h>
|
||||
#include <vecLib/vForce.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
@@ -32,12 +31,12 @@ DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(BlockSparseMM)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(NumberOfElements)
|
||||
@@ -47,8 +46,11 @@ DEFAULT(ErfInv)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
@@ -78,6 +80,8 @@ DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
@@ -99,7 +103,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -114,7 +118,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -129,7 +133,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,6 +197,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
if (a.is_donatable()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
int size = a.data_size();
|
||||
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -264,7 +288,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -277,7 +301,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -292,7 +316,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,12 +327,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[exp] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,12 +390,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log1p] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,7 +401,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -400,7 +416,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x * y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,7 +427,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else {
|
||||
unary(in, out, [](auto x) { return -x; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -498,7 +514,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
} else {
|
||||
unary(in, out, [](auto x) { return x * x; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -524,7 +540,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -542,7 +558,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -554,7 +570,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
},
|
||||
UseDefaultBinaryOp());
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x - y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -18,49 +18,61 @@ void _qmm_t_4_64(
|
||||
const float* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
int K,
|
||||
int B,
|
||||
bool batched_w) {
|
||||
constexpr int bits = 4;
|
||||
constexpr int group_size = 64;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const float* scales_local = scales;
|
||||
const float* biases_local = biases;
|
||||
int w_els = N * K / pack_factor;
|
||||
int g_els = w_els * pack_factor / group_size;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
const simd_float16* x_local = (simd_float16*)x;
|
||||
simd_float16 sum = 0;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
float scale = *scales_local++;
|
||||
float bias = *biases_local++;
|
||||
for (int i = 0; i < B; i++) {
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const float* scales_local = scales;
|
||||
const float* biases_local = biases;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
||||
// TODO: vectorize this properly
|
||||
simd_uint16 wi;
|
||||
for (int e = 0; e < 2; e++) {
|
||||
uint32_t wii = *w_local++;
|
||||
for (int p = 0; p < 8; p++) {
|
||||
wi[e * 8 + p] = wii & bitmask;
|
||||
wii >>= bits;
|
||||
for (int n = 0; n < N; n++) {
|
||||
const simd_float16* x_local = (simd_float16*)x;
|
||||
simd_float16 sum = 0;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
float scale = *scales_local++;
|
||||
float bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
||||
// TODO: vectorize this properly
|
||||
simd_uint16 wi;
|
||||
for (int e = 0; e < 2; e++) {
|
||||
uint32_t wii = *w_local++;
|
||||
for (int p = 0; p < 8; p++) {
|
||||
wi[e * 8 + p] = wii & bitmask;
|
||||
wii >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
simd_float16 wf = simd_float(wi);
|
||||
wf *= scale;
|
||||
wf += bias;
|
||||
simd_float16 wf = simd_float(wi);
|
||||
wf *= scale;
|
||||
wf += bias;
|
||||
|
||||
sum += (*x_local) * wf;
|
||||
x_local++;
|
||||
sum += (*x_local) * wf;
|
||||
x_local++;
|
||||
}
|
||||
}
|
||||
|
||||
*result = simd_reduce_add(sum);
|
||||
result++;
|
||||
}
|
||||
|
||||
*result = simd_reduce_add(sum);
|
||||
result++;
|
||||
x += K;
|
||||
}
|
||||
if (batched_w) {
|
||||
w += w_els;
|
||||
scales += g_els;
|
||||
biases += g_els;
|
||||
}
|
||||
|
||||
x += K;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (condition) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int K = x.shape(-1);
|
||||
int M = x.size() / K;
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
int B = x.size() / K / M;
|
||||
bool batched_w = w.ndim() > 2;
|
||||
_qmm_t_4_64(
|
||||
out.data<float>(),
|
||||
x.data<float>(),
|
||||
@@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
biases.data<float>(),
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
K,
|
||||
B,
|
||||
batched_w);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
@@ -2,8 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
#include <vecLib/vDSP.h>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
@@ -3,7 +3,10 @@
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
@@ -30,8 +33,8 @@ namespace {
|
||||
* Note: The implementation below is a general fast exp. There could be faster
|
||||
* implementations for numbers strictly < 0.
|
||||
*/
|
||||
inline simd_float16 simd_fast_exp(simd_float16 x) {
|
||||
x *= 1.442695; // multiply with log_2(e)
|
||||
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
|
||||
auto x = x_init * 1.442695; // multiply with log_2(e)
|
||||
simd_float16 ipart, fpart;
|
||||
simd_int16 epart;
|
||||
x = simd_clamp(x, -80, 80);
|
||||
@@ -50,28 +53,30 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
|
||||
// bitshifting
|
||||
epart = (simd_int(ipart) + 127) << 23;
|
||||
|
||||
return (*(simd_float16*)&epart) * x;
|
||||
// Avoid supressing NaNs
|
||||
simd_int16 eq = (x_init == x_init);
|
||||
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
|
||||
}
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
/**
|
||||
* The ARM neon equivalent of the fast exp above.
|
||||
*/
|
||||
inline float16x8_t neon_fast_exp(float16x8_t x) {
|
||||
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
|
||||
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
|
||||
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
|
||||
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
|
||||
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
|
||||
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
|
||||
|
||||
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
|
||||
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
|
||||
float16x8_t fpart = vsubq_f16(x, ipart);
|
||||
|
||||
x = vdupq_n_f16(1.535336188319500e-4f);
|
||||
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
|
||||
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
@@ -107,53 +112,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct AccelerateSimdOps {
|
||||
VT init(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return *(VT*)a;
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
*(VT*)dst = x;
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
};
|
||||
|
||||
VT exp(VT x) {
|
||||
return simd_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return simd_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return simd_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct NeonFp16SimdOps {
|
||||
VT init(T a) {
|
||||
@@ -170,7 +128,7 @@ struct NeonFp16SimdOps {
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return vmaxq_f16(a, b);
|
||||
};
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return neon_fast_exp(x);
|
||||
@@ -201,6 +159,55 @@ struct NeonFp16SimdOps {
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct AccelerateSimdOps {
|
||||
VT init(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return *(VT*)a;
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
*(VT*)dst = x;
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return simd_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return simd_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return simd_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
|
||||
void softmax(const array& in, array& out) {
|
||||
Ops ops;
|
||||
@@ -362,12 +369,16 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
} else {
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
softmax<
|
||||
float16_t,
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
eval(inputs, out); // Redirect to common backend for consistency
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
|
@@ -1,8 +1,8 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vecLib/BNNS/bnns.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include "mlx/dtype.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@@ -1,5 +1,4 @@
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(COMPILER ${CMAKE_C_COMPILER})
|
||||
set(CLANG TRUE)
|
||||
else()
|
||||
@@ -7,67 +6,57 @@ else()
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND /bin/bash
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
${COMPILER}
|
||||
${PROJECT_SOURCE_DIR}
|
||||
${CLANG}
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG}
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
ops.h)
|
||||
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
ops.h
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
cpu_compiled_preamble
|
||||
DEPENDS compiled_preamble.cpp
|
||||
)
|
||||
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
|
||||
|
||||
add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
)
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
if (IOS)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
|
||||
)
|
||||
if(IOS)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
|
||||
)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
|
||||
endif()
|
||||
|
@@ -196,6 +196,20 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
@@ -293,4 +307,25 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::ArcTan2());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::ArcTan2());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[arctan2] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctan2] Cannot compute inverse tangent for arrays"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,6 +1,8 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
@@ -41,13 +43,15 @@ void set_binary_op_output_data(
|
||||
array& out,
|
||||
BinaryOpType bopt,
|
||||
bool donate_with_move = false) {
|
||||
bool b_donatable = is_donatable(b, out);
|
||||
bool a_donatable = is_donatable(a, out);
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
if (b_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
@@ -62,7 +66,7 @@ void set_binary_op_output_data(
|
||||
}
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
if (a_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
@@ -77,13 +81,13 @@ void set_binary_op_output_data(
|
||||
}
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
if (a_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
} else if (b_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
@@ -98,16 +102,14 @@ void set_binary_op_output_data(
|
||||
}
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
if (a.is_donatable() && a.flags().row_contiguous &&
|
||||
a.itemsize() == out.itemsize() && a.size() == out.size()) {
|
||||
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
} else if (
|
||||
b.is_donatable() && b.flags().row_contiguous &&
|
||||
b.itemsize() == out.itemsize() && b.size() == out.size()) {
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
@@ -120,19 +122,7 @@ void set_binary_op_output_data(
|
||||
}
|
||||
}
|
||||
|
||||
struct UseDefaultBinaryOp {
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
// Should we throw? This should normally never be called.
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
// Should we throw? This should normally never be called.
|
||||
assert(false);
|
||||
}
|
||||
};
|
||||
struct UseDefaultBinaryOp {};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultVectorScalar {
|
||||
@@ -148,18 +138,6 @@ struct DefaultVectorScalar {
|
||||
a++;
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
T scalar = *b;
|
||||
while (size-- > 0) {
|
||||
auto dst = op(*a, scalar);
|
||||
*dst_a = dst.first;
|
||||
*dst_b = dst.second;
|
||||
dst_a++;
|
||||
dst_b++;
|
||||
a++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
@@ -176,18 +154,6 @@ struct DefaultScalarVector {
|
||||
b++;
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
T scalar = *a;
|
||||
while (size-- > 0) {
|
||||
auto dst = op(scalar, *b);
|
||||
*dst_a = dst.first;
|
||||
*dst_b = dst.second;
|
||||
dst_a++;
|
||||
dst_b++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
@@ -204,204 +170,110 @@ struct DefaultVectorVector {
|
||||
b++;
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
while (size-- > 0) {
|
||||
auto dst = op(*a, *b);
|
||||
*dst_a = dst.first;
|
||||
*dst_b = dst.second;
|
||||
dst_a++;
|
||||
dst_b++;
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
a_idx += a.strides()[0];
|
||||
b_idx += b.strides()[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims1(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
template <typename T, typename U, typename Op, int D, bool Strided>
|
||||
void binary_op_dims(
|
||||
const T* a,
|
||||
const T* b,
|
||||
U* out,
|
||||
Op op,
|
||||
int stride) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; i++) {
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||
a_idx += a.strides()[0];
|
||||
b_idx += b.strides()[0];
|
||||
dst += stride;
|
||||
}
|
||||
}
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
auto stride_out = out_strides[axis];
|
||||
auto N = shape[axis];
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
a_idx += a.strides()[1];
|
||||
b_idx += b.strides()[1];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims2(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
int stride) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||
a_idx += a.strides()[1];
|
||||
b_idx += b.strides()[1];
|
||||
dst += stride;
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
a_idx += a.strides()[2];
|
||||
b_idx += b.strides()[2];
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
binary_op_dims<T, U, Op, D - 1, Strided>(
|
||||
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
|
||||
} else {
|
||||
if constexpr (Strided) {
|
||||
op(a, b, out, stride_out);
|
||||
} else {
|
||||
*out = op(*a, *b);
|
||||
}
|
||||
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
out += stride_out;
|
||||
a += stride_a;
|
||||
b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
|
||||
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
a_idx += a.strides()[3];
|
||||
b_idx += b.strides()[3];
|
||||
}
|
||||
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
|
||||
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
|
||||
}
|
||||
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op) {
|
||||
switch (out.ndim()) {
|
||||
case 1:
|
||||
binary_op_dims1<T, U, Op>(a, b, out, op);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims2<T, U, Op>(a, b, out, op);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims3<T, U, Op>(a, b, out, op);
|
||||
return;
|
||||
case 4:
|
||||
binary_op_dims4<T, U, Op>(a, b, out, op);
|
||||
return;
|
||||
}
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < out.size(); i++) {
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, bool Strided, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
int dim,
|
||||
int stride) {
|
||||
// Number of dimensions to loop over for vectorized ops
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
switch (dim) {
|
||||
case 1:
|
||||
binary_op_dims1<T, U, Op>(a, b, out, op, stride);
|
||||
binary_op_dims<T, U, Op, 1, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims2<T, U, Op>(a, b, out, op, stride);
|
||||
binary_op_dims<T, U, Op, 2, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < out.size(); i += stride) {
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||
dst += stride;
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
|
||||
size_t stride = out_strides[dim - 4];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
dim - 3);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,29 +320,33 @@ void binary_op(
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||
const auto& a_strides = new_strides[0];
|
||||
const auto& b_strides = new_strides[1];
|
||||
const auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto& strides = out.strides();
|
||||
auto leftmost_rc_dim = [&strides](const array& arr) {
|
||||
int d = arr.ndim() - 1;
|
||||
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
|
||||
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a);
|
||||
auto b_rc_dim = leftmost_rc_dim(b);
|
||||
auto a_rc_dim = leftmost_rc_dim(a_strides);
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const array& arr) {
|
||||
int d = arr.ndim() - 1;
|
||||
for (; d >= 0 && arr.strides()[d] == 0; d--) {
|
||||
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a);
|
||||
auto b_s_dim = leftmost_s_dim(b);
|
||||
auto a_s_dim = leftmost_s_dim(a_strides);
|
||||
auto b_s_dim = leftmost_s_dim(b_strides);
|
||||
|
||||
auto ndim = out.ndim();
|
||||
auto ndim = new_shape.size();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
int dim = ndim;
|
||||
@@ -492,27 +368,27 @@ void binary_op(
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
size_t stride;
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
stride = 1;
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
} else {
|
||||
stride = strides[dim - 1];
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, op);
|
||||
binary_op_dispatch_dims<T, U, false>(
|
||||
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -529,9 +405,9 @@ void binary_op(
|
||||
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||
// with template specializations and overloading. Would it be simpler?
|
||||
|
||||
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
@@ -552,7 +428,8 @@ void binary_op(
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opsv and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
@@ -567,7 +444,8 @@ void binary_op(
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvs and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
@@ -583,7 +461,8 @@ void binary_op(
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opvv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
|
||||
|
@@ -9,168 +9,43 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims1(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < out_a.size(); ++i) {
|
||||
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
dst_a[i] = dst.first;
|
||||
dst_b[i] = dst.second;
|
||||
a_idx += a.strides()[0];
|
||||
b_idx += b.strides()[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims1(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
template <typename T, typename U, typename Op, int D>
|
||||
void binary_op_dims(
|
||||
const T* a,
|
||||
const T* b,
|
||||
U* out_a,
|
||||
U* out_b,
|
||||
Op op,
|
||||
int stride) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; i++) {
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||
a_idx += a.strides()[0];
|
||||
b_idx += b.strides()[0];
|
||||
dst_a += stride;
|
||||
dst_b += stride;
|
||||
}
|
||||
}
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
auto stride_out = out_strides[axis];
|
||||
auto N = shape[axis];
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims2(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
dst_a[out_idx] = dst.first;
|
||||
dst_b[out_idx++] = dst.second;
|
||||
a_idx += a.strides()[1];
|
||||
b_idx += b.strides()[1];
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
binary_op_dims<T, U, Op, D - 1>(
|
||||
a,
|
||||
b,
|
||||
out_a,
|
||||
out_b,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
axis + 1);
|
||||
} else {
|
||||
std::tie(*out_a, *out_b) = op(*a, *b);
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims2(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
int stride) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||
a_idx += a.strides()[1];
|
||||
b_idx += b.strides()[1];
|
||||
dst_a += stride;
|
||||
dst_b += stride;
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims3(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
dst_a[out_idx] = dst.first;
|
||||
dst_b[out_idx++] = dst.second;
|
||||
a_idx += a.strides()[2];
|
||||
b_idx += b.strides()[2];
|
||||
}
|
||||
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims4(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
|
||||
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
dst_a[out_idx] = dst.first;
|
||||
dst_b[out_idx++] = dst.second;
|
||||
a_idx += a.strides()[3];
|
||||
b_idx += b.strides()[3];
|
||||
}
|
||||
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
|
||||
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
|
||||
}
|
||||
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
a += stride_a;
|
||||
b += stride_b;
|
||||
out_a += stride_out;
|
||||
out_b += stride_out;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,352 +56,160 @@ void binary_op_dispatch_dims(
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
switch (out_a.ndim()) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out_a.strides()});
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& out_strides = strides[2];
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_a_ptr = out_a.data<U>();
|
||||
U* out_b_ptr = out_b.data<U>();
|
||||
|
||||
int ndim = shape.size();
|
||||
switch (ndim) {
|
||||
case 1:
|
||||
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
|
||||
binary_op_dims<T, U, Op, 1>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
|
||||
return;
|
||||
case 4:
|
||||
binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
for (size_t i = 0; i < out_a.size(); i++) {
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||
std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
|
||||
size_t stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_a_ptr + elem,
|
||||
out_b_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
ndim - 2);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
int dim,
|
||||
int stride) {
|
||||
// Number of dimensions to loop over for vectorized ops
|
||||
switch (dim) {
|
||||
case 1:
|
||||
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
|
||||
return;
|
||||
}
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
for (size_t i = 0; i < out_a.size(); i += stride) {
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||
dst_a += stride;
|
||||
dst_b += stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename OpSV,
|
||||
typename OpVS,
|
||||
typename OpVV>
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == BinaryOpType::General) {
|
||||
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
|
||||
return;
|
||||
}
|
||||
|
||||
auto a_ptr = a.data<T>();
|
||||
auto b_ptr = b.data<T>();
|
||||
auto out_a_ptr = out_a.data<U>();
|
||||
auto out_b_ptr = out_b.data<U>();
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
|
||||
op(*a.data<T>(), *b.data<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
opsv(
|
||||
a.data<T>(),
|
||||
b.data<T>(),
|
||||
out_a.data<U>(),
|
||||
out_b.data<U>(),
|
||||
b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
opvs(
|
||||
a.data<T>(),
|
||||
b.data<T>(),
|
||||
out_a.data<U>(),
|
||||
out_b.data<U>(),
|
||||
a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
opvv(
|
||||
a.data<T>(),
|
||||
b.data<T>(),
|
||||
out_a.data<U>(),
|
||||
out_b.data<U>(),
|
||||
out_a.size());
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto& strides = out_a.strides();
|
||||
auto leftmost_rc_dim = [&strides](const array& arr) {
|
||||
int d = arr.ndim() - 1;
|
||||
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
} else if (bopt == BinaryOpType::ScalarVector) {
|
||||
for (size_t i = 0; i < b.size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a);
|
||||
auto b_rc_dim = leftmost_rc_dim(b);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const array& arr) {
|
||||
int d = arr.ndim() - 1;
|
||||
for (; d >= 0 && arr.strides()[d] == 0; d--) {
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
}
|
||||
} else { // VectorVector
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a);
|
||||
auto b_s_dim = leftmost_s_dim(b);
|
||||
|
||||
auto ndim = out_a.ndim();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
dim = d;
|
||||
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
dim = d;
|
||||
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::ScalarVector;
|
||||
dim = d;
|
||||
}
|
||||
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
size_t stride;
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
stride = 1;
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
} else {
|
||||
stride = strides[dim - 1];
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||
// with template specializations and overloading. Would it be simpler?
|
||||
|
||||
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv and opvs were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opsv and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
opvs,
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
opvs,
|
||||
opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvs and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
opsv,
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opvs was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
opsv,
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
opsv,
|
||||
opvs,
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// All ops provided
|
||||
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
DefaultScalarVector<T, T, Op> opsv(op);
|
||||
DefaultVectorScalar<T, T, Op> opvs(op);
|
||||
DefaultVectorVector<T, T, Op> opvv(op);
|
||||
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
|
||||
}
|
||||
|
||||
template <typename... Ops>
|
||||
template <typename Op>
|
||||
void binary(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Ops... ops) {
|
||||
Op op) {
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, ops...);
|
||||
binary_op<bool>(a, b, outputs, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, ops...);
|
||||
binary_op<uint8_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, ops...);
|
||||
binary_op<uint16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, ops...);
|
||||
binary_op<uint32_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, ops...);
|
||||
binary_op<uint64_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, ops...);
|
||||
binary_op<int8_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, ops...);
|
||||
binary_op<int16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, ops...);
|
||||
binary_op<int32_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, ops...);
|
||||
binary_op<int64_t>(a, b, outputs, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, ops...);
|
||||
binary_op<float16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, ops...);
|
||||
binary_op<float>(a, b, outputs, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, ops...);
|
||||
binary_op<bfloat16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, outputs, ops...);
|
||||
binary_op<complex64_t>(a, b, outputs, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
74
mlx/backend/common/cholesky.cpp
Normal file
74
mlx/backend/common/cholesky.cpp
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
// Lapack uses the column-major convention. We take advantage of the fact that
|
||||
// the matrix should be symmetric:
|
||||
// (A)ᵀ = A
|
||||
// and that a column-major lower triangular matrix is a row-major upper
|
||||
// triangular matrix, so uplo is the opposite of what we would expect from
|
||||
// upper
|
||||
|
||||
char uplo = (upper) ? 'L' : 'U';
|
||||
|
||||
// The decomposition is computed in place, so just copy the input to the
|
||||
// output.
|
||||
copy(
|
||||
a,
|
||||
factor,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
float* matrix = factor.data<float>();
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute Cholesky factorization.
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(spotrf)
|
||||
(
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
|
||||
// TODO: We do nothing when the matrix is not positive semi-definite
|
||||
// because throwing an error would result in a crash. If we figure out how
|
||||
// to catch errors from the implementation we should throw.
|
||||
if (info < 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[cholesky] Cholesky decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Zero out the upper/lower triangle while advancing the pointer to the
|
||||
// next matrix at the same time.
|
||||
for (int row = 0; row < N; row++) {
|
||||
if (upper) {
|
||||
std::fill(matrix, matrix + row, 0);
|
||||
} else {
|
||||
std::fill(matrix + row + 1, matrix + N, 0);
|
||||
}
|
||||
matrix += N;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Cholesky::eval] only supports float32.");
|
||||
}
|
||||
cholesky_impl(inputs[0], output, upper_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
303
mlx/backend/common/common.cpp
Normal file
303
mlx/backend/common/common.cpp
Normal file
@@ -0,0 +1,303 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (!in.flags().row_contiguous) {
|
||||
// Just ensuring that inputs[0] came from the ops which would ensure the
|
||||
// input is row contiguous.
|
||||
throw std::runtime_error(
|
||||
"AsStrided must be used with row contiguous arrays only.");
|
||||
}
|
||||
|
||||
// Compute the flags given the shape and strides
|
||||
bool row_contiguous = true, col_contiguous = true;
|
||||
size_t r = 1, c = 1;
|
||||
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||
r *= shape_[i];
|
||||
c *= shape_[j];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
// TODO: Compute the contiguous flag in a better way cause now we are
|
||||
// unnecessarily strict.
|
||||
flags.contiguous = row_contiguous || col_contiguous;
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
|
||||
// There is no easy way to compute the actual data size so we use out.size().
|
||||
// The contiguous flag will almost certainly not be set so no code should
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
std::vector<size_t> strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
i++, j++) {
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
}
|
||||
}
|
||||
|
||||
void Depends::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
numel *= inputs[0].shape(ax);
|
||||
}
|
||||
|
||||
if (inverted_) {
|
||||
numel = 1.0 / numel;
|
||||
}
|
||||
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
*out.data<bool>() = static_cast<bool>(numel);
|
||||
break;
|
||||
case uint8:
|
||||
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
|
||||
break;
|
||||
case uint16:
|
||||
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
|
||||
break;
|
||||
case uint32:
|
||||
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
|
||||
break;
|
||||
case uint64:
|
||||
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
|
||||
break;
|
||||
case int8:
|
||||
*out.data<int8_t>() = static_cast<int8_t>(numel);
|
||||
break;
|
||||
case int16:
|
||||
*out.data<int16_t>() = static_cast<int16_t>(numel);
|
||||
break;
|
||||
case int32:
|
||||
*out.data<int32_t>() = static_cast<int32_t>(numel);
|
||||
break;
|
||||
case int64:
|
||||
*out.data<int64_t>() = static_cast<int64_t>(numel);
|
||||
break;
|
||||
case float16:
|
||||
*out.data<float16_t>() = static_cast<float16_t>(numel);
|
||||
break;
|
||||
case float32:
|
||||
*out.data<float>() = static_cast<float>(numel);
|
||||
break;
|
||||
case bfloat16:
|
||||
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
|
||||
break;
|
||||
case complex64:
|
||||
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
}
|
||||
|
||||
// Special case for scalars
|
||||
if (in.ndim() == 0) {
|
||||
std::vector<size_t> out_strides(out.ndim(), 0);
|
||||
return {false, out_strides};
|
||||
}
|
||||
|
||||
// Firstly let's collapse all the contiguous dimensions of the input
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
|
||||
// If shapes fit exactly in the contiguous dims then no copy is necessary so
|
||||
// let's check.
|
||||
std::vector<size_t> out_strides;
|
||||
bool copy_necessary = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
int N = out.shape(i);
|
||||
if (j < shape.size() && shape[j] % N == 0) {
|
||||
shape[j] /= N;
|
||||
out_strides.push_back(shape[j] * strides[j]);
|
||||
j += (shape[j] == 1);
|
||||
} else if (N == 1) {
|
||||
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
|
||||
out_strides.push_back(out_strides.back());
|
||||
} else {
|
||||
copy_necessary = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
array& out) {
|
||||
auto flags = in.flags();
|
||||
if (flags.row_contiguous) {
|
||||
// For row contiguous reshapes:
|
||||
// - Shallow copy the buffer
|
||||
// - If reshaping into a vector (all singleton dimensions except one) it
|
||||
// becomes col contiguous again.
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Split::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
auto compute_new_flags = [](const auto& shape,
|
||||
const auto& strides,
|
||||
size_t in_data_size,
|
||||
auto flags) {
|
||||
size_t data_size = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.row_contiguous = true;
|
||||
flags.col_contiguous = true;
|
||||
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
|
||||
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
|
||||
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
|
||||
f_stride *= shape[i];
|
||||
b_stride *= shape[ri];
|
||||
if (strides[i] > 0) {
|
||||
data_size *= shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in_data_size) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
return std::pair<decltype(flags), size_t>{flags, data_size};
|
||||
};
|
||||
|
||||
std::vector<int> indices(1, 0);
|
||||
indices.insert(indices.end(), indices_.begin(), indices_.end());
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
size_t offset = indices[i] * in.strides()[axis_];
|
||||
auto [new_flags, data_size] = compute_new_flags(
|
||||
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
|
||||
outputs[i].copy_shared_buffer(
|
||||
in, in.strides(), new_flags, data_size, offset);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
}
|
||||
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
std::vector<size_t> out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||
f_stride *= out.shape(i);
|
||||
flags.row_contiguous &=
|
||||
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
return print_int_constant<int8_t>(os, x);
|
||||
os << static_cast<int32_t>(x.item<int8_t>());
|
||||
return;
|
||||
case int16:
|
||||
return print_int_constant<int16_t>(os, x);
|
||||
case int32:
|
||||
@@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
case int64:
|
||||
return print_int_constant<int64_t>(os, x);
|
||||
case uint8:
|
||||
return print_int_constant<uint8_t>(os, x);
|
||||
os << static_cast<uint32_t>(x.item<uint8_t>());
|
||||
return;
|
||||
case uint16:
|
||||
return print_int_constant<uint16_t>(os, x);
|
||||
case uint32:
|
||||
@@ -205,8 +207,8 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
|
@@ -2,7 +2,10 @@
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/compiled_preamble.h"
|
||||
@@ -11,22 +14,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
namespace detail {
|
||||
bool compile_available_for_device(const Device& device) {
|
||||
return true;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name);
|
||||
}
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::string& source_code = "") {
|
||||
struct CompilerCache {
|
||||
struct DLib {
|
||||
DLib(const std::string& libname) {
|
||||
lib = dlopen(libname.c_str(), RTLD_NOW);
|
||||
@@ -43,15 +31,41 @@ void* compile(
|
||||
void* lib;
|
||||
};
|
||||
// Statics to cache compiled libraries and functions
|
||||
static std::list<DLib> libs;
|
||||
static std::unordered_map<std::string, void*> kernels;
|
||||
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
if (source_code.empty()) {
|
||||
return nullptr;
|
||||
std::list<DLib> libs;
|
||||
std::unordered_map<std::string, void*> kernels;
|
||||
std::shared_mutex mtx;
|
||||
};
|
||||
|
||||
static CompilerCache cache{};
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
namespace detail {
|
||||
bool compile_available_for_device(const Device& device) {
|
||||
return true;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name);
|
||||
}
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::function<std::string(void)>& source_builder) {
|
||||
{
|
||||
std::shared_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::string source_code = source_builder();
|
||||
std::string kernel_file_name;
|
||||
|
||||
// Deal with long kernel names. Maximum length for files on macOS is 255
|
||||
@@ -89,8 +103,8 @@ void* compile(
|
||||
source_file.close();
|
||||
|
||||
std::ostringstream build_command;
|
||||
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
|
||||
<< source_file_path << " -o " << shared_lib_path;
|
||||
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
|
||||
<< source_file_path << "' -o '" << shared_lib_path << "'";
|
||||
std::string build_command_str = build_command.str();
|
||||
auto return_code = system(build_command_str.c_str());
|
||||
if (return_code) {
|
||||
@@ -102,10 +116,10 @@ void* compile(
|
||||
}
|
||||
|
||||
// load library
|
||||
libs.emplace_back(shared_lib_path);
|
||||
cache.libs.emplace_back(shared_lib_path);
|
||||
|
||||
// Load function
|
||||
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
|
||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
||||
if (!fun) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||
@@ -113,7 +127,7 @@ void* compile(
|
||||
<< dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
kernels.insert({kernel_name, fun});
|
||||
cache.kernels.insert({kernel_name, fun});
|
||||
return fun;
|
||||
}
|
||||
|
||||
@@ -315,10 +329,7 @@ void Compiled::eval_cpu(
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name);
|
||||
|
||||
// If it doesn't exist, compile it
|
||||
if (fn_ptr == nullptr) {
|
||||
auto fn_ptr = compile(kernel_name, [&]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << get_kernel_preamble() << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
@@ -333,10 +344,8 @@ void Compiled::eval_cpu(
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
kernel << "}" << std::endl;
|
||||
|
||||
// Compile and get function pointer
|
||||
fn_ptr = compile(kernel_name, kernel.str());
|
||||
}
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||
|
@@ -3,13 +3,8 @@
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -111,13 +106,17 @@ void slow_conv_2D(
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
||||
const int C = in.shape(3); // In channels
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int oW = out.shape(2); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(3); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
const int wW = wt.shape(2); // Weight spatial dim
|
||||
|
||||
const int groups = C / wt.shape(3);
|
||||
const int C_per_group = wt.shape(3);
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_H = in.strides()[1];
|
||||
const size_t in_stride_W = in.strides()[2];
|
||||
@@ -141,33 +140,35 @@ void slow_conv_2D(
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
|
||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
@@ -219,41 +220,43 @@ void slow_conv_2D(
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||
++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
|
||||
int oH_border_0 = 0;
|
||||
@@ -310,6 +313,296 @@ void slow_conv_2D(
|
||||
} // n
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const T* st_wt_ptr = wt.data<T>();
|
||||
const T* st_in_ptr = in.data<T>();
|
||||
T* st_out_ptr = out.data<T>();
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
||||
const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
|
||||
const int oD = out.shape(1); // Output spatial dim
|
||||
const int oH = out.shape(2); // Output spatial dim
|
||||
const int oW = out.shape(3); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(4); // In channels
|
||||
const int wD = wt.shape(1); // Weight spatial dim
|
||||
const int wH = wt.shape(2); // Weight spatial dim
|
||||
const int wW = wt.shape(3); // Weight spatial dim
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_D = in.strides()[1];
|
||||
const size_t in_stride_H = in.strides()[2];
|
||||
const size_t in_stride_W = in.strides()[3];
|
||||
const size_t in_stride_C = in.strides()[4];
|
||||
|
||||
const size_t wt_stride_O = wt.strides()[0];
|
||||
const size_t wt_stride_D = wt.strides()[1];
|
||||
const size_t wt_stride_H = wt.strides()[2];
|
||||
const size_t wt_stride_W = wt.strides()[3];
|
||||
const size_t wt_stride_C = wt.strides()[4];
|
||||
|
||||
const size_t out_stride_N = out.strides()[0];
|
||||
const size_t out_stride_D = out.strides()[1];
|
||||
const size_t out_stride_H = out.strides()[2];
|
||||
const size_t out_stride_W = out.strides()[3];
|
||||
const size_t out_stride_O = out.strides()[4];
|
||||
|
||||
bool is_idil_one =
|
||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;
|
||||
|
||||
auto pt_conv_no_checks = [&](const T* in_ptr,
|
||||
const T* wt_ptr,
|
||||
T* out_ptr,
|
||||
int od,
|
||||
int oh,
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wd = 0; wd < wD; ++wd) {
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wd_flip = flip ? wD - wd - 1 : wd;
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int id = id_base + wd_flip * wt_dilation[0];
|
||||
int ih = ih_base + wh_flip * wt_dilation[1];
|
||||
int iw = iw_base + ww_flip * wt_dilation[2];
|
||||
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
|
||||
} // ww
|
||||
} // wh
|
||||
} // wd
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
};
|
||||
|
||||
int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
|
||||
int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];
|
||||
|
||||
int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
|
||||
int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
|
||||
int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);
|
||||
|
||||
int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
||||
int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||
int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];
|
||||
|
||||
int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||
int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||
int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];
|
||||
|
||||
std::vector<int> base_d(f_out_jump_d);
|
||||
std::vector<int> base_h(f_out_jump_h);
|
||||
std::vector<int> base_w(f_out_jump_w);
|
||||
|
||||
for (int i = 0; i < f_out_jump_d; ++i) {
|
||||
int id_loop = i * wt_strides[0] - padding[0] + init_d;
|
||||
|
||||
int wd_base = 0;
|
||||
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
|
||||
wd_base++;
|
||||
id_loop += jump_d;
|
||||
}
|
||||
|
||||
base_d[i] = wd_base;
|
||||
}
|
||||
|
||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
|
||||
|
||||
int wh_base = 0;
|
||||
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
|
||||
wh_base++;
|
||||
ih_loop += jump_h;
|
||||
}
|
||||
|
||||
base_h[i] = wh_base;
|
||||
}
|
||||
|
||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
|
||||
|
||||
int ww_base = 0;
|
||||
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
|
||||
ww_base++;
|
||||
iw_loop += jump_w;
|
||||
}
|
||||
|
||||
base_w[j] = ww_base;
|
||||
}
|
||||
|
||||
auto pt_conv_all_checks = [&](const T* in_ptr,
|
||||
const T* wt_ptr,
|
||||
T* out_ptr,
|
||||
int od,
|
||||
int oh,
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
|
||||
int wd_base = base_d[od % f_out_jump_d];
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wd_flip = flip ? wD - wd - 1 : wd;
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int id = id_base + wd_flip * wt_dilation[0];
|
||||
int ih = ih_base + wh_flip * wt_dilation[1];
|
||||
int iw = iw_base + ww_flip * wt_dilation[2];
|
||||
|
||||
if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
|
||||
iw < iW) {
|
||||
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
|
||||
wh * wt_stride_H + ww * wt_stride_W;
|
||||
|
||||
int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;
|
||||
|
||||
const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
|
||||
ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
|
||||
} // iD, ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
} // wd
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
};
|
||||
|
||||
int oD_border_0 = 0;
|
||||
int oD_border_1 =
|
||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
|
||||
int oD_border_2 = std::max(
|
||||
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||
int oD_border_3 = oD;
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 =
|
||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 =
|
||||
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Case 1: od might put us out of bounds
|
||||
for (int od = oD_border_0; od < oD_border_1; ++od) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
// Case 2: od in bounds
|
||||
for (int od = oD_border_1; od < oD_border_2; ++od) {
|
||||
// Case 2.1: oh might put us out of bounds
|
||||
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
// Case 2.2: oh in bounds
|
||||
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||
// Case 2.2.1: ow might put us out of bounds
|
||||
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case 2.2.2: ow in bounds
|
||||
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case 2.2.3: ow might put us out of bounds
|
||||
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
// Case 2.3: oh might put us out of bounds
|
||||
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
// Case 3: od might put us out of bounds
|
||||
for (int od = oD_border_2; od < oD_border_3; ++od) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
st_in_ptr += in_stride_N;
|
||||
st_out_ptr += out_stride_N;
|
||||
|
||||
} // n
|
||||
}
|
||||
|
||||
void dispatch_slow_conv_1D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
@@ -358,10 +651,60 @@ void dispatch_slow_conv_2D(
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
if (in.dtype() == float32) {
|
||||
return slow_conv_3D<float>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else if (in.dtype() == float16) {
|
||||
return slow_conv_3D<float16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else if (in.dtype() == bfloat16) {
|
||||
return slow_conv_3D<bfloat16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval] got unsupported data type.");
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Explicit gemm conv
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
void flip_spatial_dims_inplace(array& wt) {
|
||||
T* x = wt.data<T>();
|
||||
size_t out_channels = wt.shape(0);
|
||||
size_t in_channels = wt.shape(-1);
|
||||
|
||||
// Calculate the total size of the spatial dimensions
|
||||
int spatial_size = 1;
|
||||
for (int d = 1; d < wt.ndim() - 1; ++d) {
|
||||
spatial_size *= wt.shape(d);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < out_channels; i++) {
|
||||
T* top = x + i * spatial_size * in_channels;
|
||||
T* bottom =
|
||||
x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels;
|
||||
for (size_t j = 0; j < spatial_size / 2; j++) {
|
||||
for (size_t k = 0; k < in_channels; k++) {
|
||||
std::swap(top[k], bottom[k]);
|
||||
}
|
||||
top += in_channels;
|
||||
bottom -= in_channels;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_1D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
@@ -582,6 +925,140 @@ void explicit_gemm_conv_2D_cpu(
|
||||
}
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_ND_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const bool flip) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const auto iDim = std::vector<int>(
|
||||
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
|
||||
const auto oDim = std::vector<int>(
|
||||
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(-1); // In channels
|
||||
const auto wDim = std::vector<int>(
|
||||
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
|
||||
|
||||
auto conv_dtype = float32;
|
||||
|
||||
// Pad input
|
||||
std::vector<int> padded_shape(in.shape().size());
|
||||
padded_shape.front() = N;
|
||||
for (size_t i = 0; i < iDim.size(); i++) {
|
||||
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
|
||||
}
|
||||
padded_shape.back() = C;
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = 0;
|
||||
for (size_t i = 0; i < padding.size(); i++) {
|
||||
data_offset += padding[i] * in_padded.strides()[i + 1];
|
||||
}
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
|
||||
strided_shape.front() = N;
|
||||
for (size_t i = 0; i < oDim.size(); i++) {
|
||||
strided_shape[i + 1] = oDim[i];
|
||||
}
|
||||
for (size_t i = 0; i < wDim.size(); i++) {
|
||||
strided_shape[i + 1 + oDim.size()] = wDim[i];
|
||||
}
|
||||
strided_shape.back() = C;
|
||||
|
||||
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
|
||||
strided_strides[0] = in_padded.strides()[0];
|
||||
for (size_t i = 0; i < wt_strides.size(); i++) {
|
||||
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
|
||||
}
|
||||
for (size_t i = 1; i < in_padded.strides().size(); i++) {
|
||||
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
|
||||
}
|
||||
|
||||
auto flags = in_padded.flags();
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {N, C};
|
||||
for (const auto& o : oDim) {
|
||||
strided_reshape[0] *= o;
|
||||
}
|
||||
for (const auto& w : wDim) {
|
||||
strided_reshape[1] *= w;
|
||||
}
|
||||
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
auto gemm_wt = wt;
|
||||
auto gemm_out = out;
|
||||
|
||||
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy(wt, gemm_wt, ctype);
|
||||
}
|
||||
|
||||
if (flip) {
|
||||
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
|
||||
copy(gemm_wt, gemm_wt_, CopyType::Vector);
|
||||
|
||||
flip_spatial_dims_inplace<float>(gemm_wt_);
|
||||
gemm_wt = gemm_wt_;
|
||||
}
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
}
|
||||
|
||||
// Perform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
CblasTrans, // transB
|
||||
strided_reshape[0], // M
|
||||
O, // N
|
||||
strided_reshape[1], // K
|
||||
1.0f, // alpha
|
||||
in_strided.data<float>(),
|
||||
strided_reshape[1], // lda
|
||||
gemm_wt.data<float>(),
|
||||
strided_reshape[1], // ldb
|
||||
0.0f, // beta
|
||||
gemm_out.data<float>(),
|
||||
O // ldc
|
||||
);
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy(gemm_out, out, CopyType::Vector);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Conv routing
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -595,10 +1072,15 @@ void conv_1D_cpu(
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const int groups = in.shape().back() / wt.shape().back();
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
||||
return explicit_gemm_conv_1D_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
}
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_1D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
@@ -613,10 +1095,38 @@ void conv_2D_cpu(
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const int groups = in.shape().back() / wt.shape().back();
|
||||
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
|
||||
in_dilation[1] == 1 && groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_2D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
}
|
||||
|
||||
void conv_3D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const int groups = in.shape().back() / wt.shape().back();
|
||||
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&
|
||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
|
||||
groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_3D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -625,8 +1135,20 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
auto& wt = inputs[1];
|
||||
|
||||
// 3D convolution
|
||||
if (in.ndim() == (3 + 2)) {
|
||||
return conv_3D_cpu(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
flip_);
|
||||
}
|
||||
// 2D convolution
|
||||
if (in.ndim() == (2 + 2)) {
|
||||
else if (in.ndim() == (2 + 2)) {
|
||||
return conv_2D_cpu(
|
||||
in,
|
||||
wt,
|
||||
@@ -653,7 +1175,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
else {
|
||||
std::ostringstream msg;
|
||||
msg << "[Convolution::eval] Convolution currently only supports"
|
||||
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
|
||||
<< " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
|
||||
<< " spatial dimensions";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -25,252 +26,117 @@ void copy_vector(const array& src, array& dst) {
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim1(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += i_strides[0];
|
||||
}
|
||||
}
|
||||
template <typename SrcT, typename DstT, typename StrideT, int D>
|
||||
inline void copy_dims(
|
||||
const SrcT* src,
|
||||
DstT* dst,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
int axis) {
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = shape[axis];
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_dim1(const array& src, array& dst) {
|
||||
return copy_general_dim1<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim2(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
for (int j = 0; j < data_shape[1]; ++j) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += i_strides[1];
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
copy_dims<SrcT, DstT, StrideT, D - 1>(
|
||||
src, dst, shape, i_strides, o_strides, axis + 1);
|
||||
} else {
|
||||
*dst = static_cast<DstT>(*src);
|
||||
}
|
||||
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
|
||||
src += stride_src;
|
||||
dst += stride_dst;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_dim2(const array& src, array& dst) {
|
||||
return copy_general_dim2<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim3(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
for (int j = 0; j < data_shape[1]; ++j) {
|
||||
for (int k = 0; k < data_shape[2]; ++k) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += i_strides[2];
|
||||
}
|
||||
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
|
||||
}
|
||||
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_dim3(const array& src, array& dst) {
|
||||
return copy_general_dim3<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim4(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
for (int j = 0; j < data_shape[1]; ++j) {
|
||||
for (int k = 0; k < data_shape[2]; ++k) {
|
||||
for (int ii = 0; ii < data_shape[3]; ++ii) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += i_strides[3];
|
||||
}
|
||||
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
|
||||
}
|
||||
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
|
||||
}
|
||||
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_dim4(const array& src, array& dst) {
|
||||
return copy_general_dim4<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
switch (src.ndim()) {
|
||||
case 1:
|
||||
copy_general_dim1<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_dim2<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_dim3<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_dim4<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
|
||||
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general(const array& src, array& dst) {
|
||||
return copy_general<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
inline void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
return copy_general<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t, int D>
|
||||
inline void copy_general_general_dims(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
stride_t i_offset,
|
||||
stride_t o_offset) {
|
||||
if constexpr (D > 1) {
|
||||
int axis = src.ndim() - D;
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = data_shape[axis];
|
||||
for (int i = 0; i < N; i++) {
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
i_offset += stride_src;
|
||||
o_offset += stride_dst;
|
||||
}
|
||||
} else {
|
||||
int axis = src.ndim() - 1;
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = data_shape[axis];
|
||||
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
|
||||
DstT* dst_ptr = dst.data<DstT>() + o_offset;
|
||||
for (int i = 0; i < N; i++) {
|
||||
*dst_ptr = static_cast<DstT>(*src_ptr);
|
||||
src_ptr += stride_src;
|
||||
dst_ptr += stride_dst;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
template <typename SrcT, typename DstT, typename StrideT>
|
||||
void copy_general_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
stride_t i_offset,
|
||||
stride_t o_offset) {
|
||||
switch (src.ndim()) {
|
||||
case 1:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 5:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
if (data_shape.empty()) {
|
||||
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
*dst_ptr = val;
|
||||
return;
|
||||
}
|
||||
|
||||
int size = std::accumulate(
|
||||
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
|
||||
for (int i = 0; i < src.size(); i += size) {
|
||||
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
|
||||
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||
int ndim = shape.size();
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, StrideT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, StrideT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
} else if (ndim == 3) {
|
||||
copy_dims<SrcT, DstT, StrideT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
return;
|
||||
}
|
||||
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
|
||||
StrideT stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
|
||||
for (StrideT elem = 0; elem < src.size(); elem += stride) {
|
||||
copy_dims<SrcT, DstT, StrideT, 3>(
|
||||
src_ptr + in.loc,
|
||||
dst_ptr + out.loc,
|
||||
shape,
|
||||
strides[0],
|
||||
strides[1],
|
||||
ndim - 3);
|
||||
in.step();
|
||||
out.step();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_general(const array& src, array& dst) {
|
||||
return copy_general_general<SrcT, DstT, size_t>(
|
||||
copy_general_general<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename StrideT>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>&,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
copy_general_general<SrcT, DstT, StrideT>(
|
||||
src,
|
||||
dst,
|
||||
data_shape,
|
||||
i_strides,
|
||||
make_contiguous_strides<StrideT>(data_shape),
|
||||
i_offset,
|
||||
o_offset);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general(const array& src, array& dst) {
|
||||
copy_general_general<SrcT, DstT, size_t>(
|
||||
src,
|
||||
dst,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
make_contiguous_strides<size_t>(src.shape()),
|
||||
0,
|
||||
0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (ctype) {
|
||||
@@ -285,6 +151,7 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,7 +252,7 @@ inline void copy_inplace_dispatch(
|
||||
} // namespace
|
||||
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype) {
|
||||
return copy_inplace_dispatch(src, dst, ctype);
|
||||
copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype) {
|
||||
@@ -415,20 +282,20 @@ void copy(const array& src, array& dst, CopyType ctype) {
|
||||
copy_inplace(src, dst, ctype);
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
template <typename StrideT>
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
const std::vector<StrideT>& i_strides,
|
||||
const std::vector<StrideT>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype) {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
return copy_inplace_dispatch(
|
||||
copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
@@ -437,15 +304,24 @@ void copy_inplace(
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset);
|
||||
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
return copy_inplace_dispatch(src, dst, ctype);
|
||||
copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void copy_inplace<int64_t>(
|
||||
template void copy_inplace<size_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<size_t>& i_strides,
|
||||
const std::vector<size_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
template void copy_inplace<int64_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
@@ -453,24 +329,6 @@ void copy_inplace<int64_t>(
|
||||
const std::vector<int64_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype) {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
return copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset);
|
||||
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
return copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
}
|
||||
CopyType ctype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,15 +1,10 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -34,6 +29,7 @@ DEFAULT(ArcCosh)
|
||||
DEFAULT(ArcSin)
|
||||
DEFAULT(ArcSinh)
|
||||
DEFAULT(ArcTan)
|
||||
DEFAULT(ArcTan2)
|
||||
DEFAULT(ArcTanh)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
@@ -42,15 +38,17 @@ DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(BlockSparseMM)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
@@ -66,6 +64,7 @@ DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
@@ -110,6 +109,8 @@ DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
namespace {
|
||||
|
||||
|
117
mlx/backend/common/eigh.cpp
Normal file
117
mlx/backend/common/eigh.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void ssyevd(
|
||||
char jobz,
|
||||
char uplo,
|
||||
float* a,
|
||||
int N,
|
||||
float* w,
|
||||
float* work,
|
||||
int lwork,
|
||||
int* iwork,
|
||||
int liwork) {
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(ssyevd)
|
||||
(
|
||||
/* jobz = */ &jobz,
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ a,
|
||||
/* lda = */ &N,
|
||||
/* w = */ w,
|
||||
/* work = */ work,
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ iwork,
|
||||
/* liwork = */ &liwork,
|
||||
/* info = */ &info);
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
auto vectors = compute_eigenvectors_
|
||||
? outputs[1]
|
||||
: array(a.shape(), a.dtype(), nullptr, {});
|
||||
|
||||
values.set_data(allocator::malloc_or_wait(values.nbytes()));
|
||||
|
||||
copy(
|
||||
a,
|
||||
vectors,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
if (compute_eigenvectors_) {
|
||||
// Set the strides and flags so the eigenvectors
|
||||
// are in the columns of the output
|
||||
auto flags = vectors.flags();
|
||||
auto strides = vectors.strides();
|
||||
auto ndim = a.ndim();
|
||||
std::swap(strides[ndim - 1], strides[ndim - 2]);
|
||||
|
||||
if (a.size() > 1) {
|
||||
flags.row_contiguous = false;
|
||||
if (ndim > 2) {
|
||||
flags.col_contiguous = false;
|
||||
} else {
|
||||
flags.col_contiguous = true;
|
||||
}
|
||||
}
|
||||
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
|
||||
}
|
||||
|
||||
auto vec_ptr = vectors.data<float>();
|
||||
auto eig_ptr = values.data<float>();
|
||||
|
||||
char jobz = compute_eigenvectors_ ? 'V' : 'N';
|
||||
auto N = a.shape(-1);
|
||||
|
||||
// Work query
|
||||
int lwork;
|
||||
int liwork;
|
||||
{
|
||||
float work;
|
||||
int iwork;
|
||||
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < a.size() / (N * N); ++i) {
|
||||
ssyevd(
|
||||
jobz,
|
||||
uplo_[0],
|
||||
vec_ptr,
|
||||
N,
|
||||
eig_ptr,
|
||||
static_cast<float*>(work_buf.buffer.raw_ptr()),
|
||||
lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
liwork);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
107
mlx/backend/common/hadamard.cpp
Normal file
107
mlx/backend/common/hadamard.cpp
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// n = 2^k component
|
||||
template <typename T>
|
||||
void hadamard_n(array& out, int n, int m, float scale) {
|
||||
for (int b = 0; b < out.size() / n; b++) {
|
||||
size_t loc = b * n;
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
int h = 1;
|
||||
int n_over_2 = n / 2;
|
||||
while (h < n) {
|
||||
for (int i = 0; i < n / 2; i++) {
|
||||
int k = i & (h - 1);
|
||||
int j = ((i - k) << 1) + k;
|
||||
float x = *(data_ptr + j);
|
||||
float y = *(data_ptr + j + h);
|
||||
*(data_ptr + j) = x + y;
|
||||
*(data_ptr + j + h) = x - y;
|
||||
if (h == n_over_2) {
|
||||
*(data_ptr + j) *= scale;
|
||||
*(data_ptr + j + h) *= scale;
|
||||
}
|
||||
}
|
||||
h <<= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// m component
|
||||
template <typename T>
|
||||
void hadamard_m(array& out, int n, int m, float scale) {
|
||||
auto h_matrices = hadamard_matrices();
|
||||
auto& matrix = h_matrices[m];
|
||||
auto start = 1;
|
||||
auto end = matrix.find('\n', start);
|
||||
std::vector<bool> hmat_vec;
|
||||
while (end != std::string_view::npos) {
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
hmat_vec.push_back(row[i] == '+');
|
||||
}
|
||||
start = end + 1;
|
||||
end = matrix.find('\n', start);
|
||||
}
|
||||
|
||||
for (int b = 0; b < out.size() / m / n; b++) {
|
||||
size_t loc = b * n * m;
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
for (int i = 0; i < n; i++) {
|
||||
std::vector<float> out(m);
|
||||
for (int j = 0; j < m; j++) {
|
||||
for (int k = 0; k < m; k++) {
|
||||
float x = *(data_ptr + i + k * n);
|
||||
if (hmat_vec[k + j * m]) {
|
||||
out[j] += x;
|
||||
} else {
|
||||
out[j] -= x;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < m; j++) {
|
||||
*(data_ptr + i + j * n) = out[j] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void hadamard(array& out, int n, int m, float scale) {
|
||||
float n_scale = m > 1 ? 1.0 : scale;
|
||||
hadamard_n<T>(out, n, m, n_scale);
|
||||
if (m > 1) {
|
||||
hadamard_m<T>(out, n, m, scale);
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Copy input to output
|
||||
copy(in, out, CopyType::General);
|
||||
|
||||
int axis = out.ndim() - 1;
|
||||
auto [n, m] = decompose_hadamard(out.shape(axis));
|
||||
|
||||
switch (in.dtype()) {
|
||||
case float32:
|
||||
return hadamard<float>(out, n, m, scale_);
|
||||
case float16:
|
||||
return hadamard<float16_t>(out, n, m, scale_);
|
||||
case bfloat16:
|
||||
return hadamard<bfloat16_t>(out, n, m, scale_);
|
||||
default:
|
||||
throw std::invalid_argument("[hadamard] Unsupported type.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
105
mlx/backend/common/hadamard.h
Normal file
105
mlx/backend/common/hadamard.h
Normal file
@@ -0,0 +1,105 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// From http://neilsloane.com/hadamard/
|
||||
constexpr std::string_view h12 = R"(
|
||||
+-++++++++++
|
||||
--+-+-+-+-+-
|
||||
+++-++----++
|
||||
+---+--+-++-
|
||||
+++++-++----
|
||||
+-+---+--+-+
|
||||
++--+++-++--
|
||||
+--++---+--+
|
||||
++----+++-++
|
||||
+--+-++---+-
|
||||
++++----+++-
|
||||
+-+--+-++---
|
||||
)";
|
||||
|
||||
constexpr std::string_view h20 = R"(
|
||||
+----+----++--++-++-
|
||||
-+----+---+++---+-++
|
||||
--+----+---+++-+-+-+
|
||||
---+----+---+++++-+-
|
||||
----+----++--++-++-+
|
||||
-+++++-----+--+++--+
|
||||
+-+++-+---+-+--+++--
|
||||
++-++--+---+-+--+++-
|
||||
+++-+---+---+-+--+++
|
||||
++++-----++--+-+--++
|
||||
--++-+-++-+-----++++
|
||||
---++-+-++-+---+-+++
|
||||
+---++-+-+--+--++-++
|
||||
++---++-+----+-+++-+
|
||||
-++---++-+----+++++-
|
||||
-+--+--++-+----+----
|
||||
+-+-----++-+----+---
|
||||
-+-+-+---+--+----+--
|
||||
--+-+++------+----+-
|
||||
+--+--++------+----+
|
||||
)";
|
||||
|
||||
constexpr std::string_view h28 = R"(
|
||||
+------++----++-+--+-+--++--
|
||||
-+-----+++-----+-+--+-+--++-
|
||||
--+-----+++---+-+-+----+--++
|
||||
---+-----+++---+-+-+-+--+--+
|
||||
----+-----+++---+-+-+++--+--
|
||||
-----+-----++++--+-+--++--+-
|
||||
------++----++-+--+-+--++--+
|
||||
--++++-+-------++--+++-+--+-
|
||||
---++++-+-----+-++--+-+-+--+
|
||||
+---+++--+----++-++--+-+-+--
|
||||
++---++---+----++-++--+-+-+-
|
||||
+++---+----+----++-++--+-+-+
|
||||
++++--------+-+--++-++--+-+-
|
||||
-++++--------+++--++--+--+-+
|
||||
-+-++-++--++--+--------++++-
|
||||
+-+-++--+--++--+--------++++
|
||||
-+-+-++--+--++--+----+---+++
|
||||
+-+-+-++--+--+---+---++---++
|
||||
++-+-+-++--+------+--+++---+
|
||||
-++-+-+-++--+------+-++++---
|
||||
+-++-+---++--+------+-++++--
|
||||
-++--++-+-++-+++----++------
|
||||
+-++--++-+-++-+++-----+-----
|
||||
++-++---+-+-++-+++-----+----
|
||||
-++-++-+-+-+-+--+++-----+---
|
||||
--++-++++-+-+----+++-----+--
|
||||
+--++-+-++-+-+----+++-----+-
|
||||
++--++-+-++-+-+----++------+
|
||||
)";
|
||||
|
||||
inline const std::map<int, std::string_view> hadamard_matrices() {
|
||||
return {{12, h12}, {20, h20}, {28, h28}};
|
||||
}
|
||||
|
||||
inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
// n = m*2^k
|
||||
int m = 1;
|
||||
if (!is_power_of_2(n)) {
|
||||
auto h_matrices = hadamard_matrices();
|
||||
for (auto [factor, _] : h_matrices) {
|
||||
if (n % factor == 0) {
|
||||
m = factor;
|
||||
n /= factor;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (m == 1) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -81,11 +80,18 @@ void gather(
|
||||
T* dst_ptr = out.data<T>();
|
||||
size_t out_idx = 0;
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> src_it;
|
||||
if (!can_copy && src.ndim() > 0) {
|
||||
src_it = std::move(
|
||||
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
|
||||
}
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||
auto ax = axes[ii];
|
||||
auto idx_loc = elem_to_loc(idx, inds[ii]);
|
||||
auto idx_loc = its[ii].loc;
|
||||
its[ii].step();
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
|
||||
src_idx += (idx_val * src.strides()[ax]);
|
||||
@@ -99,9 +105,10 @@ void gather(
|
||||
out_idx += slice_size;
|
||||
} else {
|
||||
for (int jj = 0; jj < slice_size; jj++) {
|
||||
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
|
||||
src_it.step();
|
||||
}
|
||||
src_it.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,21 +230,29 @@ void scatter(
|
||||
update_size *= us;
|
||||
}
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> update_it(updates);
|
||||
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
|
||||
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
for (int j = 0; j < nind; ++j) {
|
||||
auto ax = axes[j];
|
||||
auto idx_loc = elem_to_loc(i, inds[j]);
|
||||
auto idx_loc = its[j].loc;
|
||||
its[j].step();
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
|
||||
out_offset += (idx_val * out.strides()[ax]);
|
||||
}
|
||||
update_it.seek(i * update_size);
|
||||
for (int j = 0; j < update_size; ++j) {
|
||||
auto update_loc = elem_to_loc(i * update_size + j, updates);
|
||||
auto out_loc = elem_to_loc(j, update_shape, out.strides());
|
||||
op(updates.data<InT>()[update_loc],
|
||||
out.data<InT>() + out_offset + out_loc);
|
||||
op(updates.data<InT>()[update_it.loc],
|
||||
out.data<InT>() + out_offset + out_it.loc);
|
||||
update_it.step();
|
||||
out_it.step();
|
||||
}
|
||||
out_it.reset();
|
||||
update_it.reset();
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -2,18 +2,94 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(strtri)
|
||||
(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
return info;
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void inverse_impl(const array& a, array& inv) {
|
||||
void general_inv(array& inv, int N, int i) {
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
// Compute LU factorization.
|
||||
sgetrf_(
|
||||
/* m = */ &N,
|
||||
/* n = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
static const int lwork_query = -1;
|
||||
float workspace_size = 0;
|
||||
|
||||
// Compute workspace size.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ nullptr,
|
||||
/* work = */ &workspace_size,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
void tri_inv(array& inv, int N, int i, bool upper) {
|
||||
const char uplo = upper ? 'L' : 'U';
|
||||
const char diag = 'N';
|
||||
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: triangular inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
// Lapack uses the column-major convention. We take advantage of the following
|
||||
// identity to avoid transposing (see
|
||||
// https://math.stackexchange.com/a/340234):
|
||||
@@ -25,63 +101,11 @@ void inverse_impl(const array& a, array& inv) {
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute LU factorization.
|
||||
sgetrf_(
|
||||
/* m = */ &N,
|
||||
/* n = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
static const int lwork_query = -1;
|
||||
float workspace_size = 0;
|
||||
|
||||
// Compute workspace size.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ nullptr,
|
||||
/* work = */ &workspace_size,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch =
|
||||
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
if (tri) {
|
||||
tri_inv(inv, N, i, upper);
|
||||
} else {
|
||||
general_inv(inv, N, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,15 +114,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
||||
}
|
||||
inverse_impl(inputs[0], output);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0] >= 0 ? 0 : -1;
|
||||
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
|
||||
return {{linalg::inv(a, stream())}, {ax}};
|
||||
inverse_impl(inputs[0], output, tri_, upper_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,10 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
@@ -5,11 +5,9 @@
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/io/load.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <const uint8_t scalar_size>
|
||||
@@ -29,12 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
namespace mlx::core {
|
||||
|
||||
reader_->seek(offset_, std::ios_base::beg);
|
||||
reader_->read(out.data<char>(), out.nbytes());
|
||||
void load(
|
||||
array& out,
|
||||
size_t offset,
|
||||
const std::shared_ptr<io::Reader>& reader,
|
||||
bool swap_endianness_) {
|
||||
reader->read(out.data<char>(), out.nbytes(), offset);
|
||||
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
@@ -51,4 +51,11 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
load(out, offset_, reader_, swap_endianness_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
14
mlx/backend/common/load.h
Normal file
14
mlx/backend/common/load.h
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/io/load.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void load(
|
||||
array& out,
|
||||
size_t offset,
|
||||
const std::shared_ptr<io::Reader>& reader,
|
||||
bool swap_endianess);
|
||||
|
||||
} // namespace mlx::core
|
@@ -18,16 +18,19 @@ if [ "$CLANG" = "TRUE" ]; then
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
EOM
|
||||
|
||||
CC_FLAGS=""
|
||||
else
|
||||
CC_FLAGS="-std=c++17"
|
||||
fi
|
||||
|
||||
CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null)
|
||||
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
const char* get_kernel_preamble() {
|
||||
return R"preamble(
|
||||
$INCLUDES
|
||||
$CONTENT
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::detail;
|
||||
)preamble";
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user