mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 09:14:34 +08:00
Compare commits
294 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
bf7cd29970 | ||
![]() |
a000d2288c | ||
![]() |
165abf0e4c | ||
![]() |
818cda16bc | ||
![]() |
85143fecdd | ||
![]() |
35431a4ac8 | ||
![]() |
ccf1645995 | ||
![]() |
1a48713d32 | ||
![]() |
1eb04aa23f | ||
![]() |
0c65517e91 | ||
![]() |
2fdc2462c3 | ||
![]() |
be6e9d6a9f | ||
![]() |
e54cbb7ba6 | ||
![]() |
40c108766b | ||
![]() |
4cc70290f7 | ||
![]() |
74caa68d02 | ||
![]() |
3756381358 | ||
![]() |
d12573daa6 | ||
![]() |
0dbc4c7547 | ||
![]() |
06072601ce | ||
![]() |
11d2c8f7a1 | ||
![]() |
7f3f8d8f8d | ||
![]() |
b96be943dc | ||
![]() |
b670485185 | ||
![]() |
b57bd0488d | ||
![]() |
221f8d3fc2 | ||
![]() |
5c03efaf29 | ||
![]() |
7dccd42133 | ||
![]() |
1b97b2958b | ||
![]() |
e5e816a5ef | ||
![]() |
28eac18571 | ||
![]() |
5fd11c347d | ||
![]() |
ef73393a19 | ||
![]() |
ea406d5e33 | ||
![]() |
146bd69470 | ||
![]() |
316ff490b3 | ||
![]() |
d40a04f8dc | ||
![]() |
d75ae52ecd | ||
![]() |
31fea3758e | ||
![]() |
e319383ef9 | ||
![]() |
5c3ac52dd7 | ||
![]() |
ebfd3618b0 | ||
![]() |
11a9fd40f0 | ||
![]() |
4fd2fb84a6 | ||
![]() |
9852af1a19 | ||
![]() |
16750f3c51 | ||
![]() |
95b5fb8245 | ||
![]() |
83f63f2184 | ||
![]() |
cb6156d35d | ||
![]() |
506d43035c | ||
![]() |
36cff34701 | ||
![]() |
e88e474fd1 | ||
![]() |
601c6d6aa8 | ||
![]() |
ba8d6bf365 | ||
![]() |
4a5f3b21bb | ||
![]() |
fcc5ac1c64 | ||
![]() |
bad67fec37 | ||
![]() |
199aebcf77 | ||
![]() |
0de5988f92 | ||
![]() |
143e2690d5 | ||
![]() |
375446453e | ||
![]() |
1895d34c20 | ||
![]() |
09b9275027 | ||
![]() |
d3a9005454 | ||
![]() |
3f7aba8498 | ||
![]() |
65d0b8df9f | ||
![]() |
3c2f192345 | ||
![]() |
37d98ba6ff | ||
![]() |
8993382aaa | ||
![]() |
07f35c9d8a | ||
![]() |
bf17ab5002 | ||
![]() |
8fa6b322b9 | ||
![]() |
874b739f3c | ||
![]() |
077c1ee64a | ||
![]() |
2463496471 | ||
![]() |
87b7fa9ba2 | ||
![]() |
624065c074 | ||
![]() |
f27ec5e097 | ||
![]() |
f30e63353a | ||
![]() |
4fe2fa2a64 | ||
![]() |
37fc9db82c | ||
![]() |
755dcf6137 | ||
![]() |
6b4b30e3fc | ||
![]() |
86e0c79467 | ||
![]() |
98c37d3a22 | ||
![]() |
f326dd8334 | ||
![]() |
6d3bee3364 | ||
![]() |
ecb174ca9d | ||
![]() |
7a34e46677 | ||
![]() |
92c22c1ea3 | ||
![]() |
d52383367a | ||
![]() |
363d3add6d | ||
![]() |
b207c2c86b | ||
![]() |
6bf779e72b | ||
![]() |
ddf50113c5 | ||
![]() |
6589c869d6 | ||
![]() |
f6feb61f92 | ||
![]() |
c4ec836523 | ||
![]() |
550d4bf7c0 | ||
![]() |
f6e911ced0 | ||
![]() |
3d99a8d31d | ||
![]() |
a749a91c75 | ||
![]() |
49a52610b7 | ||
![]() |
d1fef34138 | ||
![]() |
9c111f176d | ||
![]() |
78e5f2d17d | ||
![]() |
90c234b7ac | ||
![]() |
135fd796d2 | ||
![]() |
78102a47ad | ||
![]() |
556cdf0e06 | ||
![]() |
275db7221a | ||
![]() |
4a9012cba0 | ||
![]() |
a2bf7693dd | ||
![]() |
d8fabaa12b | ||
![]() |
4e290d282f | ||
![]() |
e72458a3fa | ||
![]() |
a2ffea683a | ||
![]() |
c15fe3e61b | ||
![]() |
f44c132f4a | ||
![]() |
92a2fdd577 | ||
![]() |
6022d4129e | ||
![]() |
4bc446be08 | ||
![]() |
41cc7bdfdb | ||
![]() |
6e81c3e164 | ||
![]() |
2e29d0815b | ||
![]() |
1b71487e1f | ||
![]() |
1416e7b664 | ||
![]() |
29081204d1 | ||
![]() |
006d01ba42 | ||
![]() |
46dc24d835 | ||
![]() |
c9934fe8a4 | ||
![]() |
975e265f74 | ||
![]() |
c92a134b0d | ||
![]() |
3b4f066dac | ||
![]() |
b7f905787e | ||
![]() |
e3e933c6bc | ||
![]() |
1d90a76d63 | ||
![]() |
961435a243 | ||
![]() |
e9ca65c939 | ||
![]() |
753867123d | ||
![]() |
f099ebe535 | ||
![]() |
f45f70f133 | ||
![]() |
0b8aeddac6 | ||
![]() |
432ee5650b | ||
![]() |
73321b8097 | ||
![]() |
022a944367 | ||
![]() |
026ef9aae4 | ||
![]() |
a611b0bc82 | ||
![]() |
449b43762e | ||
![]() |
6ea6b4258d | ||
![]() |
48f6ca8c3a | ||
![]() |
c6d2878c1a | ||
![]() |
b34bf5d52b | ||
![]() |
608bd43604 | ||
![]() |
4c48f6460d | ||
![]() |
1331fa19f6 | ||
![]() |
dfdb284e16 | ||
![]() |
d8f41a5c0f | ||
![]() |
b9e415d19c | ||
![]() |
c82a8cc526 | ||
![]() |
75dc537e44 | ||
![]() |
cf88db44b5 | ||
![]() |
16856a0160 | ||
![]() |
d752f8e142 | ||
![]() |
d2467c320d | ||
![]() |
0d31128a44 | ||
![]() |
1ac18eac20 | ||
![]() |
526466dd09 | ||
![]() |
e7f5059fe4 | ||
![]() |
d7ac050f4b | ||
![]() |
c7edafb729 | ||
![]() |
dff4a3833f | ||
![]() |
0782a4573a | ||
![]() |
af66a09bde | ||
![]() |
436bec9fd9 | ||
![]() |
99c80a2c8b | ||
![]() |
295ce9db09 | ||
![]() |
44c1ce5e6a | ||
![]() |
144ecff849 | ||
![]() |
350095ce6e | ||
![]() |
e09bf35b28 | ||
![]() |
99c20f523e | ||
![]() |
e3b8da2a49 | ||
![]() |
a020a2d49d | ||
![]() |
930b159885 | ||
![]() |
5ad8fb7268 | ||
![]() |
2aedf3e791 | ||
![]() |
473b6b43b4 | ||
![]() |
d29770eeaa | ||
![]() |
040c3bafab | ||
![]() |
05767b026f | ||
![]() |
a83d5d60bd | ||
![]() |
ff2b58e299 | ||
![]() |
4417e37ede | ||
![]() |
79c95b6919 | ||
![]() |
1f6ab6a556 | ||
![]() |
6b0d30bb85 | ||
![]() |
447bc089b9 | ||
![]() |
fc4e5b476b | ||
![]() |
d58ac083f3 | ||
![]() |
a123c3c7d2 | ||
![]() |
9e6b8c9f48 | ||
![]() |
22fee5a383 | ||
![]() |
7365d142a3 | ||
![]() |
8b227fa9af | ||
![]() |
8c3da54c7d | ||
![]() |
acf1721b98 | ||
![]() |
f91f450141 | ||
![]() |
cd3616a463 | ||
![]() |
d35fa1db41 | ||
![]() |
e8deca84e0 | ||
![]() |
8385f93cea | ||
![]() |
2118c3dbfa | ||
![]() |
a002797d52 | ||
![]() |
1d053e0d1d | ||
![]() |
0aa65c7a6b | ||
![]() |
794feb83df | ||
![]() |
2c7df6795e | ||
![]() |
b3916cbf2b | ||
![]() |
57fe918cf8 | ||
![]() |
4912ff3ec2 | ||
![]() |
f40d17047d | ||
![]() |
2807c6aff0 | ||
![]() |
de892cb66c | ||
![]() |
37024d899c | ||
![]() |
137f55bf28 | ||
![]() |
e549f84532 | ||
![]() |
dfa9f4bc58 | ||
![]() |
e6872a4149 | ||
![]() |
f4f6e17d45 | ||
![]() |
4d4af12c6f | ||
![]() |
477397bc98 | ||
![]() |
18cca64c81 | ||
![]() |
0e5807bbcb | ||
![]() |
8eb56beb3a | ||
![]() |
ee0c2835c5 | ||
![]() |
90d04072b7 | ||
![]() |
52e1589a52 | ||
![]() |
eebd7c275d | ||
![]() |
a67bbfe745 | ||
![]() |
104c34f906 | ||
![]() |
dc2edc762c | ||
![]() |
2e02acdc83 | ||
![]() |
83f266c44c | ||
![]() |
f24200db2c | ||
![]() |
e28b57e371 | ||
![]() |
e5851e52b1 | ||
![]() |
f55908bc48 | ||
![]() |
b93c4cf378 | ||
![]() |
1e0c78b970 | ||
![]() |
76e1af0e02 | ||
![]() |
c3272d4917 | ||
![]() |
50f5d14b11 | ||
![]() |
d14a0e4ff9 | ||
![]() |
fb675de30d | ||
![]() |
25f70d4ca4 | ||
![]() |
02de234ef0 | ||
![]() |
f5df47ec6e | ||
![]() |
b9226c367c | ||
![]() |
3214629601 | ||
![]() |
072044e28f | ||
![]() |
e080290ba4 | ||
![]() |
69505b4e9b | ||
![]() |
f4ddd7dc44 | ||
![]() |
b0cd092b7f | ||
![]() |
71d1fff90a | ||
![]() |
0cfbfc9904 | ||
![]() |
2d0130f80f | ||
![]() |
c1e1c1443f | ||
![]() |
68bf1d7867 | ||
![]() |
600db7d754 | ||
![]() |
ef7b8756c0 | ||
![]() |
0b28399638 | ||
![]() |
ac6dc5d3eb | ||
![]() |
89b90dcfec | ||
![]() |
fd836d891b | ||
![]() |
976e8babbe | ||
![]() |
2520dbcf0a | ||
![]() |
430bfb4944 | ||
![]() |
08d51bf232 | ||
![]() |
cb9e585b8e | ||
![]() |
641d316484 | ||
![]() |
2b714714e1 | ||
![]() |
69a24e6a1e | ||
![]() |
5b9be57ac3 | ||
![]() |
e89c571de7 | ||
![]() |
209404239b | ||
![]() |
4e3bdb560c | ||
![]() |
86b614afcd | ||
![]() |
cfc39d84b7 | ||
![]() |
d11d77e581 | ||
![]() |
bf410cb85e | ||
![]() |
2e126aeb7e | ||
![]() |
dfbc52ce56 |
@@ -1,5 +1,8 @@
|
|||||||
version: 2.1
|
version: 2.1
|
||||||
|
|
||||||
|
orbs:
|
||||||
|
apple: ml-explore/pr-approval@0.1.0
|
||||||
|
|
||||||
parameters:
|
parameters:
|
||||||
nightly_build:
|
nightly_build:
|
||||||
type: boolean
|
type: boolean
|
||||||
@@ -7,6 +10,9 @@ parameters:
|
|||||||
weekly_build:
|
weekly_build:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
|
test_release:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
linux_build_and_test:
|
linux_build_and_test:
|
||||||
@@ -26,18 +32,28 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install --upgrade pybind11[global]
|
pip install --upgrade pybind11[global]
|
||||||
|
pip install pybind11-stubgen
|
||||||
pip install numpy
|
pip install numpy
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
- run:
|
- run:
|
||||||
name: Build python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
|
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
|
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
|
||||||
- run:
|
- run:
|
||||||
name: Run the python tests
|
name: Generate package stubs
|
||||||
command: |
|
command: |
|
||||||
python3 -m unittest discover python/tests
|
python3 setup.py generate_stubs
|
||||||
|
- run:
|
||||||
|
name: Run Python tests
|
||||||
|
command: |
|
||||||
|
python3 -m unittest discover python/tests -v
|
||||||
|
# TODO: Reenable when extension api becomes stable
|
||||||
|
# - run:
|
||||||
|
# name: Build example extension
|
||||||
|
# command: |
|
||||||
|
# cd examples/extensions && python3 -m pip install .
|
||||||
- run:
|
- run:
|
||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
@@ -47,154 +63,180 @@ jobs:
|
|||||||
command: ./build/tests/tests
|
command: ./build/tests/tests
|
||||||
|
|
||||||
mac_build_and_test:
|
mac_build_and_test:
|
||||||
machine: true
|
macos:
|
||||||
resource_class: ml-explore/m-builder
|
xcode: "15.2.0"
|
||||||
|
resource_class: macos.m1.large.gen1
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
brew install python@3.9
|
||||||
rm -r $CONDA_PREFIX/envs/runner-env
|
python3.9 -m venv env
|
||||||
conda create -y -n runner-env python=3.9
|
source env/bin/activate
|
||||||
conda activate runner-env
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install --upgrade pybind11[global]
|
pip install --upgrade pybind11[global]
|
||||||
|
pip install pybind11-stubgen
|
||||||
pip install numpy
|
pip install numpy
|
||||||
pip install torch
|
pip install torch
|
||||||
|
pip install tensorflow
|
||||||
pip install unittest-xml-reporting
|
pip install unittest-xml-reporting
|
||||||
- run:
|
- run:
|
||||||
name: Build python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
source env/bin/activate
|
||||||
conda activate runner-env
|
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
|
|
||||||
- run:
|
- run:
|
||||||
name: Run the python tests
|
name: Generate package stubs
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
source env/bin/activate
|
||||||
conda activate runner-env
|
python setup.py generate_stubs
|
||||||
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
- run:
|
||||||
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
|
name: Run Python tests
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||||
|
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||||
|
# TODO: Reenable when extension api becomes stable
|
||||||
|
# - run:
|
||||||
|
# name: Build example extension
|
||||||
|
# command: |
|
||||||
|
# cd examples/extensions && python3.11 -m pip install .
|
||||||
- store_test_results:
|
- store_test_results:
|
||||||
path: test-results
|
path: test-results
|
||||||
|
- run:
|
||||||
|
name: Build CPP only
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
mkdir -p build && cd build && cmake .. && make -j
|
||||||
|
- run:
|
||||||
|
name: Run CPP tests
|
||||||
|
command: |
|
||||||
|
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
||||||
|
DEVICE=cpu ./build/tests/tests
|
||||||
|
|
||||||
build_release:
|
build_release:
|
||||||
machine: true
|
|
||||||
resource_class: ml-explore/m-builder
|
|
||||||
parameters:
|
parameters:
|
||||||
python_version:
|
python_version:
|
||||||
type: string
|
type: string
|
||||||
default: "3.9"
|
default: "3.9"
|
||||||
macos_version:
|
xcode_version:
|
||||||
type: string
|
type: string
|
||||||
default: "14"
|
default: "15.2.0"
|
||||||
|
build_env:
|
||||||
|
type: string
|
||||||
|
default: ""
|
||||||
|
macos:
|
||||||
|
xcode: << parameters.xcode_version >>
|
||||||
|
resource_class: macos.m1.large.gen1
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
brew install python@<< parameters.python_version >>
|
||||||
rm -r $CONDA_PREFIX/envs/runner-env
|
python<< parameters.python_version >> -m venv env
|
||||||
conda create -y -n runner-env python=<< parameters.python_version >>
|
source env/bin/activate
|
||||||
conda activate runner-env
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install --upgrade pybind11[global]
|
pip install --upgrade pybind11[global]
|
||||||
|
pip install --upgrade setuptools
|
||||||
|
pip install pybind11-stubgen
|
||||||
pip install numpy
|
pip install numpy
|
||||||
pip install twine
|
pip install twine
|
||||||
|
pip install build
|
||||||
- run:
|
- run:
|
||||||
name: Build pacakge
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
source env/bin/activate
|
||||||
conda activate runner-env
|
DEV_RELEASE=1 \
|
||||||
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
|
|
||||||
PYPI_RELEASE=1 \
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||||
python setup.py bdist_wheel
|
pip install . -v
|
||||||
twine upload dist/* --repository mlx
|
- run:
|
||||||
|
name: Generate package stubs
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
python setup.py generate_stubs
|
||||||
|
- run:
|
||||||
|
name: Build Python package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
<< parameters.build_env >> \
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||||
|
python -m build -w
|
||||||
|
- when:
|
||||||
|
condition: << parameters.build_env >>
|
||||||
|
steps:
|
||||||
|
- run:
|
||||||
|
name: Upload package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
twine upload dist/*
|
||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: dist/
|
path: dist/
|
||||||
|
|
||||||
build_dev_release:
|
build_linux_test_release:
|
||||||
machine: true
|
|
||||||
resource_class: ml-explore/m-builder
|
|
||||||
parameters:
|
parameters:
|
||||||
python_version:
|
python_version:
|
||||||
type: string
|
type: string
|
||||||
default: "3.9"
|
default: "3.9"
|
||||||
macos_version:
|
extra_env:
|
||||||
type: string
|
type: string
|
||||||
default: "14"
|
default: "DEV_RELEASE=1"
|
||||||
|
docker:
|
||||||
|
- image: ubuntu:20.04
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Build wheel
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
PYTHON=python<< parameters.python_version >>
|
||||||
rm -r $CONDA_PREFIX/envs/runner-env
|
apt-get update
|
||||||
conda create -y -n runner-env python=<< parameters.python_version >>
|
apt-get upgrade -y
|
||||||
conda activate runner-env
|
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
|
||||||
|
apt-get install -y apt-utils
|
||||||
|
apt-get install -y software-properties-common
|
||||||
|
add-apt-repository -y ppa:deadsnakes/ppa
|
||||||
|
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||||
|
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
apt-get install -y build-essential git
|
||||||
|
$PYTHON -m venv env
|
||||||
|
source env/bin/activate
|
||||||
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install --upgrade pybind11[global]
|
pip install --upgrade pybind11[global]
|
||||||
|
pip install --upgrade setuptools
|
||||||
|
pip install pybind11-stubgen
|
||||||
pip install numpy
|
pip install numpy
|
||||||
pip install twine
|
pip install auditwheel
|
||||||
- run:
|
pip install patchelf
|
||||||
name: Build pacakge
|
pip install build
|
||||||
command: |
|
<< parameters.extra_env >> \
|
||||||
eval "$(conda shell.bash hook)"
|
|
||||||
conda activate runner-env
|
|
||||||
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
|
|
||||||
DEV_RELEASE=1 \
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||||
python setup.py bdist_wheel
|
pip install . -v
|
||||||
twine upload dist/* --repository mlx
|
python setup.py generate_stubs
|
||||||
- store_artifacts:
|
<< parameters.extra_env >> \
|
||||||
path: dist/
|
|
||||||
|
|
||||||
build_package:
|
|
||||||
machine: true
|
|
||||||
resource_class: ml-explore/m-builder
|
|
||||||
parameters:
|
|
||||||
python_version:
|
|
||||||
type: string
|
|
||||||
default: "3.9"
|
|
||||||
macos_version:
|
|
||||||
type: string
|
|
||||||
default: "14"
|
|
||||||
steps:
|
|
||||||
- checkout
|
|
||||||
- run:
|
|
||||||
name: Install dependencies
|
|
||||||
command: |
|
|
||||||
eval "$(conda shell.bash hook)"
|
|
||||||
rm -r $CONDA_PREFIX/envs/runner-env
|
|
||||||
conda create -y -n runner-env python=<< parameters.python_version >>
|
|
||||||
conda activate runner-env
|
|
||||||
pip install --upgrade cmake
|
|
||||||
pip install --upgrade pybind11[global]
|
|
||||||
pip install numpy
|
|
||||||
pip install twine
|
|
||||||
- run:
|
|
||||||
name: Build pacakge
|
|
||||||
command: |
|
|
||||||
eval "$(conda shell.bash hook)"
|
|
||||||
conda activate runner-env
|
|
||||||
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||||
python setup.py bdist_wheel
|
python -m build --wheel
|
||||||
|
auditwheel show dist/*
|
||||||
|
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: dist/
|
path: wheelhouse/
|
||||||
|
|
||||||
workflows:
|
workflows:
|
||||||
build_and_test:
|
build_and_test:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
|
- matches:
|
||||||
|
pattern: "^(?!pull/)[-\\w]+$"
|
||||||
|
value: << pipeline.git.branch >>
|
||||||
- not: << pipeline.parameters.nightly_build >>
|
- not: << pipeline.parameters.nightly_build >>
|
||||||
- not: << pipeline.parameters.weekly_build >>
|
- not: << pipeline.parameters.weekly_build >>
|
||||||
|
- not: << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- linux_build_and_test
|
|
||||||
- mac_build_and_test
|
- mac_build_and_test
|
||||||
|
- linux_build_and_test
|
||||||
- build_release:
|
- build_release:
|
||||||
filters:
|
filters:
|
||||||
tags:
|
tags:
|
||||||
@@ -204,20 +246,53 @@ workflows:
|
|||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||||
macos_version: ["13", "14"]
|
xcode_version: ["14.3.1", "15.2.0"]
|
||||||
|
build_env: ["PYPI_RELEASE=1"]
|
||||||
|
prb:
|
||||||
|
when:
|
||||||
|
matches:
|
||||||
|
pattern: "^pull/\\d+(/head)?$"
|
||||||
|
value: << pipeline.git.branch >>
|
||||||
|
jobs:
|
||||||
|
- hold:
|
||||||
|
type: approval
|
||||||
|
- apple/authenticate:
|
||||||
|
context: pr-approval
|
||||||
|
- mac_build_and_test:
|
||||||
|
requires: [ hold ]
|
||||||
|
- linux_build_and_test:
|
||||||
|
requires: [ hold ]
|
||||||
nightly_build:
|
nightly_build:
|
||||||
when: << pipeline.parameters.nightly_build >>
|
when:
|
||||||
|
and:
|
||||||
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
|
- << pipeline.parameters.nightly_build >>
|
||||||
jobs:
|
jobs:
|
||||||
- build_package:
|
- build_release:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||||
macos_version: ["13", "14"]
|
xcode_version: ["14.3.1", "15.2.0"]
|
||||||
weekly_build:
|
weekly_build:
|
||||||
when: << pipeline.parameters.weekly_build >>
|
when:
|
||||||
|
and:
|
||||||
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
|
- << pipeline.parameters.weekly_build >>
|
||||||
jobs:
|
jobs:
|
||||||
- build_dev_release:
|
- build_release:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||||
macos_version: ["13", "14"]
|
xcode_version: ["14.3.1", "15.2.0"]
|
||||||
|
build_env: ["DEV_RELEASE=1"]
|
||||||
|
linux_test_release:
|
||||||
|
when:
|
||||||
|
and:
|
||||||
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
|
- << pipeline.parameters.test_release >>
|
||||||
|
jobs:
|
||||||
|
- build_linux_test_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||||
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
|
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
---
|
||||||
|
name: Bug report
|
||||||
|
about: Create a report about an issue you've encountered
|
||||||
|
title: "[BUG] "
|
||||||
|
labels: ''
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Describe the bug**
|
||||||
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
|
**To Reproduce**
|
||||||
|
|
||||||
|
Include code snippet
|
||||||
|
```python
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expected behavior**
|
||||||
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Desktop (please complete the following information):**
|
||||||
|
- OS Version: [e.g. MacOS 14.1.2]
|
||||||
|
- Version [e.g. 0.7.0]
|
||||||
|
|
||||||
|
**Additional context**
|
||||||
|
Add any other context about the problem here.
|
12
.github/pull_request_template.md
vendored
Normal file
12
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
## Proposed changes
|
||||||
|
|
||||||
|
Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
Put an `x` in the boxes that apply.
|
||||||
|
|
||||||
|
- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
|
||||||
|
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
|
||||||
|
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||||
|
- [ ] I have updated the necessary documentation (if needed)
|
20
.github/workflows/pull_request.yml
vendored
Normal file
20
.github/workflows/pull_request.yml
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check_lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: 3.8
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install pre-commit black isort clang-format
|
||||||
|
- name: Run lint
|
||||||
|
run: |
|
||||||
|
pre-commit run --all-files
|
8
.gitignore
vendored
8
.gitignore
vendored
@@ -6,11 +6,16 @@ __pycache__/
|
|||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
|
|
||||||
|
# tensor files
|
||||||
|
*.safe
|
||||||
|
*.safetensors
|
||||||
|
|
||||||
# Metal libraries
|
# Metal libraries
|
||||||
*.metallib
|
*.metallib
|
||||||
venv/
|
venv/
|
||||||
|
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
|
python/mlx/core
|
||||||
python/mlx/share
|
python/mlx/share
|
||||||
python/mlx/include
|
python/mlx/include
|
||||||
.Python
|
.Python
|
||||||
@@ -74,3 +79,6 @@ build/
|
|||||||
# VSCode
|
# VSCode
|
||||||
.vscode/
|
.vscode/
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# Jetbrains
|
||||||
|
.cache
|
||||||
|
@@ -1,9 +1,16 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v14.0.6
|
rev: v17.0.6
|
||||||
hooks:
|
hooks:
|
||||||
- id: clang-format
|
- id: clang-format
|
||||||
- repo: https://github.com/psf/black
|
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||||
rev: 22.10.0
|
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||||
|
rev: 24.2.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: 5.13.2
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
args:
|
||||||
|
- --profile=black
|
||||||
|
@@ -1,3 +1,24 @@
|
|||||||
|
# Individual Contributors
|
||||||
|
|
||||||
|
If you wish to be acknowledged for your contributions, please list your name
|
||||||
|
with a short description of your contribution(s) below. For example:
|
||||||
|
|
||||||
|
- Jane Smith: Added the `foo` and `bar` ops.
|
||||||
|
|
||||||
|
MLX was developed with contributions from the following individuals:
|
||||||
|
|
||||||
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||||
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
|
||||||
|
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
|
||||||
|
|
||||||
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
</a>
|
||||||
|
|
||||||
|
# Third-Party Software
|
||||||
|
|
||||||
MLX leverages several third-party software, listed here together with
|
MLX leverages several third-party software, listed here together with
|
||||||
their license copied verbatim.
|
their license copied verbatim.
|
||||||
|
|
||||||
@@ -231,4 +252,4 @@ Unless required by applicable law or agreed to in writing, software
|
|||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
@@ -1,6 +1,6 @@
|
|||||||
cmake_minimum_required(VERSION 3.24)
|
cmake_minimum_required(VERSION 3.24)
|
||||||
|
|
||||||
project(mlx LANGUAGES CXX)
|
project(mlx LANGUAGES C CXX)
|
||||||
|
|
||||||
# ----------------------------- Setup -----------------------------
|
# ----------------------------- Setup -----------------------------
|
||||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||||
@@ -18,7 +18,34 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
|
|||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
|
||||||
if(NOT MLX_VERSION)
|
if(NOT MLX_VERSION)
|
||||||
set(MLX_VERSION 0.0.3)
|
set(MLX_VERSION 0.3.0)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# --------------------- Processor tests -------------------------
|
||||||
|
|
||||||
|
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
||||||
|
|
||||||
|
set(MLX_BUILD_ARM OFF)
|
||||||
|
|
||||||
|
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
|
|
||||||
|
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
|
||||||
|
message(FATAL_ERROR
|
||||||
|
"Building for x86_64 on macOS is not supported."
|
||||||
|
" If you are on an Apple silicon system, check the build"
|
||||||
|
" documentation for possible fixes: "
|
||||||
|
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
||||||
|
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||||
|
message(WARNING
|
||||||
|
"Building for x86_64 on macOS is not supported."
|
||||||
|
" If you are on an Apple silicon system, "
|
||||||
|
" make sure you are building for arm64.")
|
||||||
|
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||||
|
set(MLX_BUILD_ARM ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
else()
|
||||||
|
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ----------------------------- Lib -----------------------------
|
# ----------------------------- Lib -----------------------------
|
||||||
@@ -37,15 +64,18 @@ endif()
|
|||||||
|
|
||||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||||
message(STATUS "Metal not found. Unable to build GPU")
|
message(STATUS "Metal not found. Unable to build GPU")
|
||||||
|
set(MLX_BUILD_METAL OFF)
|
||||||
elseif (MLX_BUILD_METAL)
|
elseif (MLX_BUILD_METAL)
|
||||||
message(STATUS "Building METAL sources")
|
message(STATUS "Building METAL sources")
|
||||||
add_compile_definitions(_METAL_)
|
add_compile_definitions(_METAL_)
|
||||||
|
|
||||||
|
# Throw an error if xcrun not found
|
||||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||||
OUTPUT_VARIABLE MACOS_VERSION)
|
OUTPUT_VARIABLE MACOS_VERSION
|
||||||
|
COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
|
||||||
|
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||||
|
|
||||||
message(STATUS "Building with SDK for MacOS version ${MACOS_VERSION}")
|
|
||||||
|
|
||||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||||
@@ -53,7 +83,7 @@ elseif (MLX_BUILD_METAL)
|
|||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
|
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
|
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "MLX requires MacOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
|
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
@@ -75,13 +105,13 @@ elseif (MLX_BUILD_METAL)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||||
if (ACCELERATE_LIBRARY)
|
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||||
set(MLX_BUILD_ACCELERATE ON)
|
set(MLX_BUILD_ACCELERATE ON)
|
||||||
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
|
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
|
||||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||||
else()
|
else()
|
||||||
message(STATUS "Accelerate not found, using default backend.")
|
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||||
set(MLX_BUILD_ACCELERATE OFF)
|
set(MLX_BUILD_ACCELERATE OFF)
|
||||||
#set(BLA_VENDOR Generic)
|
#set(BLA_VENDOR Generic)
|
||||||
find_package(BLAS REQUIRED)
|
find_package(BLAS REQUIRED)
|
||||||
@@ -93,16 +123,27 @@ else()
|
|||||||
/usr/include
|
/usr/include
|
||||||
/usr/local/include
|
/usr/local/include
|
||||||
$ENV{BLAS_HOME}/include)
|
$ENV{BLAS_HOME}/include)
|
||||||
message(STATUS ${BLAS_LIBRARIES})
|
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||||
message(STATUS ${BLAS_INCLUDE_DIRS})
|
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||||
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
||||||
|
find_package(LAPACK REQUIRED)
|
||||||
|
if (NOT LAPACK_FOUND)
|
||||||
|
message(FATAL_ERROR "Must have LAPACK installed")
|
||||||
|
endif()
|
||||||
|
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||||
|
/usr/include
|
||||||
|
/usr/local/include)
|
||||||
|
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||||
|
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||||
|
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||||
|
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx
|
mlx
|
||||||
PUBLIC
|
PUBLIC
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||||
$<INSTALL_INTERFACE:include>
|
$<INSTALL_INTERFACE:include>
|
||||||
@@ -128,6 +169,8 @@ if (MLX_BUILD_BENCHMARKS)
|
|||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------- Installation -----------------------------
|
# ----------------------------- Installation -----------------------------
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
@@ -197,4 +240,4 @@ install(
|
|||||||
install(
|
install(
|
||||||
DIRECTORY ${CMAKE_MODULE_PATH}/
|
DIRECTORY ${CMAKE_MODULE_PATH}/
|
||||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||||
)
|
)
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
include CMakeLists.txt
|
include CMakeLists.txt
|
||||||
recursive-include mlx/ *
|
recursive-include mlx/ *
|
||||||
include python/src/*
|
include python/src/*
|
||||||
|
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||||
|
52
README.md
52
README.md
@@ -6,8 +6,8 @@
|
|||||||
|
|
||||||
[](https://circleci.com/gh/ml-explore/mlx)
|
[](https://circleci.com/gh/ml-explore/mlx)
|
||||||
|
|
||||||
MLX is an array framework for machine learning on Apple silicon, brought to you
|
MLX is an array framework for machine learning research on Apple silicon,
|
||||||
by Apple machine learning research.
|
brought to you by Apple machine learning research.
|
||||||
|
|
||||||
Some key features of MLX include:
|
Some key features of MLX include:
|
||||||
|
|
||||||
@@ -16,24 +16,24 @@ Some key features of MLX include:
|
|||||||
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
|
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
|
||||||
that closely follow PyTorch to simplify building more complex models.
|
that closely follow PyTorch to simplify building more complex models.
|
||||||
|
|
||||||
- **Composable function transformations**: MLX has composable function
|
- **Composable function transformations**: MLX supports composable function
|
||||||
transformations for automatic differentiation, automatic vectorization,
|
transformations for automatic differentiation, automatic vectorization,
|
||||||
and computation graph optimization.
|
and computation graph optimization.
|
||||||
|
|
||||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||||
materialized when needed.
|
materialized when needed.
|
||||||
|
|
||||||
- **Dynamic graph construction**: Computation graphs in MLX are built
|
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||||
dynamically. Changing the shapes of function arguments does not trigger
|
dynamically. Changing the shapes of function arguments does not trigger
|
||||||
slow compilations, and debugging is simple and intuitive.
|
slow compilations, and debugging is simple and intuitive.
|
||||||
|
|
||||||
- **Multi-device**: Operations can run on any of the supported devices
|
- **Multi-device**: Operations can run on any of the supported devices
|
||||||
(currently, the CPU and GPU).
|
(currently the CPU and the GPU).
|
||||||
|
|
||||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||||
Operations on MLX arrays can be performed on any of the supported
|
Operations on MLX arrays can be performed on any of the supported
|
||||||
device types without moving data.
|
device types without transferring data.
|
||||||
|
|
||||||
MLX is designed by machine learning researchers for machine learning
|
MLX is designed by machine learning researchers for machine learning
|
||||||
researchers. The framework is intended to be user-friendly, but still efficient
|
researchers. The framework is intended to be user-friendly, but still efficient
|
||||||
@@ -53,7 +53,7 @@ variety of examples, including:
|
|||||||
|
|
||||||
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
|
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
|
||||||
- Large-scale text generation with
|
- Large-scale text generation with
|
||||||
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
|
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
|
||||||
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
|
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
|
||||||
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
|
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
|
||||||
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
|
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
|
||||||
@@ -61,17 +61,25 @@ variety of examples, including:
|
|||||||
## Quickstart
|
## Quickstart
|
||||||
|
|
||||||
See the [quick start
|
See the [quick start
|
||||||
guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
|
guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
|
||||||
in the documentation.
|
in the documentation.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
MLX is available on [PyPi](https://pypi.org/project/mlx/). To install the Python API, run:
|
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
||||||
|
|
||||||
|
**With `pip`**:
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install mlx
|
pip install mlx
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**With `conda`**:
|
||||||
|
|
||||||
|
```
|
||||||
|
conda install -c conda-forge mlx
|
||||||
|
```
|
||||||
|
|
||||||
Checkout the
|
Checkout the
|
||||||
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
|
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
|
||||||
for more information on building the C++ and Python APIs from source.
|
for more information on building the C++ and Python APIs from source.
|
||||||
@@ -79,4 +87,28 @@ for more information on building the C++ and Python APIs from source.
|
|||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
|
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
|
||||||
on contributing to MLX.
|
on contributing to MLX. See the
|
||||||
|
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
|
||||||
|
information on building from source, and running tests.
|
||||||
|
|
||||||
|
We are grateful for all of [our
|
||||||
|
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
||||||
|
to MLX and wish to be acknowledged, please add your name to the list in your
|
||||||
|
pull request.
|
||||||
|
|
||||||
|
## Citing MLX
|
||||||
|
|
||||||
|
The MLX software suite was initially developed with equal contribution by Awni
|
||||||
|
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
||||||
|
MLX useful in your research and wish to cite it, please use the following
|
||||||
|
BibTex entry:
|
||||||
|
|
||||||
|
```
|
||||||
|
@software{mlx2023,
|
||||||
|
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||||
|
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||||
|
url = {https://github.com/ml-explore},
|
||||||
|
version = {0.0},
|
||||||
|
year = {2023},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
@@ -233,6 +233,20 @@ void time_gather_scatter() {
|
|||||||
TIME(single_element_add);
|
TIME(single_element_add);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void time_divmod() {
|
||||||
|
auto a = random::normal({1000});
|
||||||
|
auto b = random::normal({1000});
|
||||||
|
eval({a, b});
|
||||||
|
|
||||||
|
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
|
||||||
|
TIME(divmod_fused);
|
||||||
|
|
||||||
|
auto divmod_separate = [&a, &b]() {
|
||||||
|
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
|
||||||
|
};
|
||||||
|
TIME(divmod_separate);
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||||
time_creation_ops();
|
time_creation_ops();
|
||||||
@@ -246,4 +260,5 @@ int main() {
|
|||||||
time_matmul();
|
time_matmul();
|
||||||
time_reductions();
|
time_reductions();
|
||||||
time_gather_scatter();
|
time_gather_scatter();
|
||||||
|
time_divmod();
|
||||||
}
|
}
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,8 +1,8 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
B = 8
|
B = 8
|
||||||
|
@@ -1,13 +1,14 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
import argparse
|
||||||
import mlx.core as mx
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import os
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
device_name = device_name.decode("utf-8").strip("\n")
|
device_name = device_name.decode("utf-8").strip("\n")
|
||||||
@@ -165,13 +166,13 @@ if __name__ == "__main__":
|
|||||||
dtypes = ("float32", "float16")
|
dtypes = ("float32", "float16")
|
||||||
transposes = ("nn", "nt", "tn")
|
transposes = ("nn", "nt", "tn")
|
||||||
shapes = (
|
shapes = (
|
||||||
|
(16, 234, 768, 3072),
|
||||||
|
(1, 64, 64, 25344),
|
||||||
(16, 1024, 1024, 1024),
|
(16, 1024, 1024, 1024),
|
||||||
(1, 1024, 1024, 2048),
|
(1, 1024, 1024, 2048),
|
||||||
(4, 1024, 1024, 4096),
|
(4, 1024, 1024, 4096),
|
||||||
(4, 1024, 4096, 1024),
|
(4, 1024, 4096, 1024),
|
||||||
(1, 4096, 4096, 4096),
|
(1, 4096, 4096, 4096),
|
||||||
(15, 1023, 1023, 1023),
|
|
||||||
(17, 1025, 1025, 1025),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
|
@@ -1,14 +1,14 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
import argparse
|
||||||
import mlx.core as mx
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
results_dir = "./results"
|
results_dir = "./results"
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
|
|||||||
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
|
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
|
||||||
|
|
||||||
|
|
||||||
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
|
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
|
||||||
np_dtype = getattr(np, dtype)
|
np_dtype = getattr(np, dtype)
|
||||||
mlx_gb_s = []
|
mlx_gb_s = []
|
||||||
mlx_gflops = []
|
mlx_gflops = []
|
||||||
@@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
|
|||||||
ax.legend()
|
ax.legend()
|
||||||
|
|
||||||
|
|
||||||
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
|
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
||||||
np_dtype = getattr(np, dtype)
|
np_dtype = getattr(np, dtype)
|
||||||
mlx_gb_s = []
|
mlx_gb_s = []
|
||||||
mlx_gflops = []
|
mlx_gflops = []
|
||||||
|
@@ -4,8 +4,10 @@ import argparse
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def int_or_list(x):
|
def int_or_list(x):
|
||||||
@@ -22,6 +24,16 @@ def none_or_list(x):
|
|||||||
return [int(xi) for xi in x.split(",")]
|
return [int(xi) for xi in x.split(",")]
|
||||||
|
|
||||||
|
|
||||||
|
def dtype_from_str(x):
|
||||||
|
if x == "":
|
||||||
|
return mx.float32
|
||||||
|
else:
|
||||||
|
dt = getattr(mx, x)
|
||||||
|
if not isinstance(dt, mx.Dtype):
|
||||||
|
raise ValueError(f"{x} is not an mlx dtype")
|
||||||
|
return dt
|
||||||
|
|
||||||
|
|
||||||
def bench(f, *args):
|
def bench(f, *args):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
f(*args)
|
f(*args)
|
||||||
@@ -48,6 +60,63 @@ def matmul(x, y):
|
|||||||
mx.eval(ys)
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
|
||||||
|
ys = []
|
||||||
|
for i in range(10):
|
||||||
|
ys.append(
|
||||||
|
mx.quantized_matmul(
|
||||||
|
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
quant_matmul = {
|
||||||
|
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
|
||||||
|
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
|
||||||
|
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
|
||||||
|
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
|
||||||
|
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
|
||||||
|
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
|
||||||
|
"quant_matmul_128_2": partial(
|
||||||
|
_quant_matmul, transpose=False, group_size=128, bits=2
|
||||||
|
),
|
||||||
|
"quant_matmul_128_4": partial(
|
||||||
|
_quant_matmul, transpose=False, group_size=128, bits=4
|
||||||
|
),
|
||||||
|
"quant_matmul_128_8": partial(
|
||||||
|
_quant_matmul, transpose=False, group_size=128, bits=8
|
||||||
|
),
|
||||||
|
"quant_matmul_t_32_2": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=32, bits=2
|
||||||
|
),
|
||||||
|
"quant_matmul_t_32_4": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=32, bits=4
|
||||||
|
),
|
||||||
|
"quant_matmul_t_32_8": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=32, bits=8
|
||||||
|
),
|
||||||
|
"quant_matmul_t_64_2": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=64, bits=2
|
||||||
|
),
|
||||||
|
"quant_matmul_t_64_4": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=64, bits=4
|
||||||
|
),
|
||||||
|
"quant_matmul_t_64_8": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=64, bits=8
|
||||||
|
),
|
||||||
|
"quant_matmul_t_128_2": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=128, bits=2
|
||||||
|
),
|
||||||
|
"quant_matmul_t_128_4": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=128, bits=4
|
||||||
|
),
|
||||||
|
"quant_matmul_t_128_8": partial(
|
||||||
|
_quant_matmul, transpose=True, group_size=128, bits=8
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def conv1d(x, y):
|
def conv1d(x, y):
|
||||||
ys = []
|
ys = []
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
@@ -95,7 +164,77 @@ def softmax_fused(axis, x):
|
|||||||
def relu(x):
|
def relu(x):
|
||||||
y = x
|
y = x
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
y = mx.maximum(y, 0)
|
y = nn.relu(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def leaky_relu(x: mx.array):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.leaky_relu(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def prelu(x: mx.array):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.prelu(y, mx.ones(1))
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def softplus(x: mx.array):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.softplus(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def mish(x: mx.array):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.mish(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def leaky_relu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.leaky_relu(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def elu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.elu(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def relu6(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.relu6(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def softplus(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.softplus(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def celu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.celu(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def log_sigmoid(x):
|
||||||
|
y = x
|
||||||
|
+    for i in range(100):
+        y = nn.log_sigmoid(y)
     mx.eval(y)


@@ -130,6 +269,13 @@ def linear(w, b, x):
     mx.eval(ys)


+def linear_fused(w, b, x):
+    ys = []
+    for i in range(10):
+        ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
+    mx.eval(ys)
+
+
 def rope(x):
     *_, N, D = x.shape
     ys = []

@@ -180,6 +326,20 @@ def topk(axis, x):
     mx.eval(ys)


+def step_function(x):
+    y = x
+    for i in range(100):
+        y = nn.step(x)
+    mx.eval(y)
+
+
+def selu(x):
+    y = x
+    for i in range(100):
+        y = nn.selu(x)
+    mx.eval(y)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("benchmark", help="Choose the benchmark to run")

@@ -211,9 +371,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--fused", action="store_true", help="Use fused functions where possible"
     )
-    parser.add_argument(
-        "--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
-    )
+    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")

     args = parser.parse_args()

@@ -230,11 +388,15 @@ if __name__ == "__main__":
         mx.set_default_device(mx.cpu)
     else:
         mx.set_default_device(mx.gpu)
-    dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
-        args.dtype
-    ]
+    types = args.dtype
+    if not types:
+        types = [mx.float32]
+    if len(types) < len(args.size):
+        types = types + [types[0]] * (len(args.size) - len(types))

     xs = []
-    for size in args.size:
+    for size, dtype in zip(args.size, types):
         xs.append(mx.random.normal(size).astype(dtype))
     for i, t in enumerate(args.transpose):
         if t is None:

@@ -250,8 +412,14 @@ if __name__ == "__main__":
     elif args.benchmark == "matmul":
         print(bench(matmul, *xs))

+    elif args.benchmark.startswith("quant_matmul"):
+        print(bench(quant_matmul[args.benchmark], *xs))
+
     elif args.benchmark == "linear":
-        print(bench(linear, *xs))
+        if args.fused:
+            print(bench(linear_fused, *xs))
+        else:
+            print(bench(linear, *xs))

     elif args.benchmark == "sum_axis":
         print(bench(reduction, "sum", axis, x))

@@ -277,6 +445,26 @@ if __name__ == "__main__":
     elif args.benchmark == "relu":
         print(bench(relu, x))

+    elif args.benchmark == "elu":
+        print(bench(elu, x))
+
+    elif args.benchmark == "relu6":
+        print(bench(relu6, x))
+
+    elif args.benchmark == "celu":
+        print(bench(celu, x))
+
+    elif args.benchmark == "log_sigmoid":
+        print(bench(log_sigmoid, x))
+
+    elif args.benchmark == "leaky_relu":
+        print(bench(leaky_relu, x))
+
+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+
+    elif args.benchmark == "softplus":
+        print(bench(softplus, x))
+
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
+
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))

@@ -311,5 +499,11 @@ if __name__ == "__main__":
     elif args.benchmark == "topk":
         print(bench(topk, axis, x))

+    elif args.benchmark == "step":
+        print(bench(step_function, x))
+
+    elif args.benchmark == "selu":
+        print(bench(selu, x))
+
     else:
         raise ValueError("Unknown benchmark")
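For context, the MLX benchmark's `--dtype` flag now appends one entry per occurrence and is padded to match the number of `--size` arguments. A minimal standalone sketch of that padding rule, using made-up values rather than the script's argparse namespace:

    import mlx.core as mx

    # Hypothetical parsed values: two sizes but only one requested dtype.
    sizes = [(16, 1024), (16, 1024)]
    types = [mx.float16]

    # Same padding rule as in the diff: reuse the first dtype for missing entries.
    if not types:
        types = [mx.float32]
    if len(types) < len(sizes):
        types = types + [types[0]] * (len(sizes) - len(types))

    xs = [mx.random.normal(size).astype(dtype) for size, dtype in zip(sizes, types)]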
@@ -22,6 +22,16 @@ def none_or_list(x):
     return [int(xi) for xi in x.split(",")]


+def dtype_from_str(x):
+    if x == "":
+        return torch.float32
+    else:
+        dt = getattr(torch, x)
+        if not isinstance(dt, torch.dtype):
+            raise ValueError(f"{x} is not a torch dtype")
+        return dt
+
+
 def bench(f, *args):
     for i in range(10):
         f(*args)

@@ -115,6 +125,70 @@ def relu(x):
     sync_if_needed(x)


+@torch.no_grad()
+def leaky_relu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.leaky_relu(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def elu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.elu(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def celu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.celu(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def relu6(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu6(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def softplus(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.softplus(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def log_sigmoid(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.logsigmoid(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def prelu(x: torch.Tensor) -> torch.Tensor:
+    y = x
+    for _ in range(100):
+        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def mish(x: torch.Tensor) -> torch.Tensor:
+    y = x
+    for _ in range(100):
+        return torch.nn.functional.mish(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def scalar_mult(x):
     y = x

@@ -209,6 +283,14 @@ def topk(axis, x):
     sync_if_needed(x)


+@torch.no_grad()
+def selu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.selu(y)
+    sync_if_needed(x)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("benchmark", help="Choose the benchmark to run")

@@ -240,7 +322,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--fused", action="store_true", help="Use fused functions where possible"
     )
-    parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
+    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")

     args = parser.parse_args()

@@ -255,9 +337,15 @@ if __name__ == "__main__":

     torch.set_num_threads(1)
     device = "cpu" if args.cpu else "mps"
-    dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
+    types = args.dtype
+    if not types:
+        types = [torch.float32]
+    if len(types) < len(args.size):
+        types = types + [types[0]] * (len(args.size) - len(types))

     xs = []
-    for size in args.size:
+    for size, dtype in zip(args.size, types):
         xs.append(torch.randn(*size).to(device).to(dtype))
     for i, t in enumerate(args.transpose):
         if t is None:

@@ -302,6 +390,28 @@ if __name__ == "__main__":
     elif args.benchmark == "relu":
         print(bench(relu, x))

+    elif args.benchmark == "leaky_relu":
+        print(bench(leaky_relu, x))
+
+    elif args.benchmark == "elu":
+        print(bench(elu, x))
+
+    elif args.benchmark == "relu6":
+        print(bench(relu6, x))
+
+    elif args.benchmark == "softplus":
+        print(bench(softplus, x))
+
+    elif args.benchmark == "celu":
+        print(bench(celu, x))
+
+    elif args.benchmark == "log_sigmoid":
+        print(bench(log_sigmoid, x))
+
+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
+
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))
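The new `dtype_from_str` helper above resolves dtype names through `getattr(torch, ...)`, so any valid torch dtype name can be passed to the repeated `--dtype` flag. A small usage sketch (the printed values are illustrative, mirroring the helper in the diff):

    import torch

    def dtype_from_str(x):
        if x == "":
            return torch.float32
        dt = getattr(torch, x)
        if not isinstance(dt, torch.dtype):
            raise ValueError(f"{x} is not a torch dtype")
        return dt

    print(dtype_from_str("float16"))   # torch.float16
    print(dtype_from_str("bfloat16"))  # torch.bfloat16
    print(dtype_from_str(""))          # torch.float32 (the default)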
@@ -62,7 +62,7 @@ def make_predicate(positive_filter, negative_filter):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
+    parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
     parser.add_argument(
         "--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
     )

@@ -80,10 +80,8 @@ if __name__ == "__main__":
     _filter = make_predicate(args.filter, args.negative_filter)

     if args.mlx_dtypes:
-        compare_filtered = (
-            lambda x: compare_mlx_dtypes(
-                x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
-            )
+        compare_filtered = lambda x: (
+            compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
             if _filter(x)
             else None
         )

@@ -125,6 +123,14 @@ if __name__ == "__main__":
     compare_filtered("sum_axis --size 16x128x1024 --axis 1")
     compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
     compare_filtered("sum_axis --size 16x128x1024 --axis 0")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
     compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
     compare_filtered("argmax --size 10x1024x128 --axis 1")
     compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")

@@ -193,6 +199,27 @@ if __name__ == "__main__":
     compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
     compare_filtered("relu --size 32x16x1024")
     compare_filtered("relu --size 32x16x1024 --cpu")
+    compare_filtered("leaky_relu --size 32x16x1024")
+    compare_filtered("leaky_relu --size 32x16x1024 --cpu")
+    compare_filtered("elu --size 32x16x1024")
+    compare_filtered("elu --size 32x16x1024 --cpu")
+    compare_filtered("relu6 --size 32x16x1024")
+    compare_filtered("relu6 --size 32x16x1024 --cpu")
+    compare_filtered("softplus --size 32x16x1024")
+    compare_filtered("softplus --size 32x16x1024 --cpu")
+    compare_filtered("celu --size 32x16x1024")
+    compare_filtered("celu --size 32x16x1024 --cpu")
+    compare_filtered("log_sigmoid --size 32x16x1024")
+    compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
+    compare_filtered("step --size 32x16x1024")
+    compare_filtered("step --size 32x16x1024 --cpu")
+    compare_filtered("selu --size 32x16x1024")
+    compare_filtered("selu --size 32x16x1024 --cpu")
+    # compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm
+    compare_filtered("mish --size 32x16x1024 --cpu")
+    compare_filtered("prelu --size 32x16x1024")
+    compare_filtered("prelu --size 32x16x1024 --cpu")
+
     compare_filtered("scalar_mul --size 32x16x1024")
     compare_filtered("scalar_mul --size 32x16x1024 --cpu")
     compare_filtered("cross_entropy --size 256x1024")
benchmarks/python/gather_bench.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
from time import time

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_gather_mlx(x_shape, idx_shape):
    def gather(x, idx):
        mx.eval(x[idx])

    idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)

    runtime = measure_runtime(gather, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_gather_torch(x_shape, idx_shape, device):
    def gather(x, idx, device):
        _ = x[idx]
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    idx_shapes = [(1_000_000,), (100_000,), ()]
    x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]

    for x_shape, idx_shape in zip(x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_gather_mlx(x_shape, idx_shape)
        benchmark_gather_torch(x_shape, idx_shape, device=device)
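Because MLX evaluates lazily, the indexing op in `gather` only does work when `mx.eval` runs, which is why the eval sits inside the timed closure. A standalone sketch of the same timing pattern, with arbitrary shapes and a hand-rolled loop instead of the `measure_runtime` helper:

    import time

    import mlx.core as mx

    x = mx.random.normal((100, 64)).astype(mx.float32)
    idx = mx.random.randint(0, 99, (100_000,))

    tic = time.time()
    for _ in range(100):
        mx.eval(x[idx])  # force the gather to actually execute
    print(f"{(time.time() - tic) * 10:.3f} ms per iteration")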
@@ -1,198 +0,0 @@ (entire file removed)
# Copyright © 2023 Apple Inc.

import math
import time

import jax
import jax.numpy as jnp
from flax import linen as nn


class RoPE(nn.Module):
    dims: int
    traditional: bool = False

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = jnp.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        dtype=jnp.float32,
    ):
        D = D // 2
        positions = jnp.arange(offset, N, dtype=dtype)
        freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
        theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        return costheta, sintheta

    @nn.compact
    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = x.reshape((-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.reshape(shape)


class LlamaAttention(nn.Module):
    dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        num_heads = self.num_heads
        dims = self.dims

        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = jnp.concatenate([key_cache, keys], axis=2)
            values = jnp.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = jax.nn.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    dims: int
    mlp_dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        dims = self.dims
        mlp_dims = self.mlp_dims
        num_heads = self.num_heads

        self.attention = LlamaAttention(dims, num_heads, dtype)

        self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
        self.norm2 = nn.RMSNorm(param_dtype=self.dtype)

        self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = jax.nn.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    jax.block_until_ready((y, c))

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    jax.block_until_ready((y, c))

    end = time.time()
    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    dtype = jnp.float16

    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

    x = jax.random.normal(k1, (1, 1, D), dtype)
    cache = [
        jax.random.normal(k2, [1, H, C, D // H], dtype),
        jax.random.normal(k3, [1, H, C, D // H], dtype),
    ]

    layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
    params = layer.init(k4, x, mask=None, cache=cache)["params"]

    @jax.jit
    def model_fn(x, mask, cache):
        return layer.apply({"params": params}, x, mask=mask, cache=cache)

    T = measure(model_fn, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
@@ -1,118 +0,0 @@ (entire file removed)
# Copyright © 2023 Apple Inc.

import math
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, False)
        self.key_proj = nn.Linear(dims, dims, False)
        self.value_proj = nn.Linear(dims, dims, False)
        self.out_proj = nn.Linear(dims, dims, False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, False)
        self.linear2 = nn.Linear(dims, mlp_dims, False)
        self.linear3 = nn.Linear(mlp_dims, dims, False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    mx.eval(y, c)

    start = time.time()
    rs = []
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        rs.append((y, c))
    mx.eval(rs)
    end = time.time()

    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    mx.set_default_device(mx.gpu)
    dtype = mx.float16

    layer = LlamaEncoderLayer(D, F, H)
    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
    x = mx.random.normal([1, 1, D], dtype=dtype)
    cache = [
        mx.random.normal([1, H, C, D // H], dtype=dtype),
        mx.random.normal([1, H, C, D // H], dtype=dtype),
    ]
    mx.eval(x, cache)

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
@@ -1,199 +0,0 @@ (entire file removed)
# Copyright © 2023 Apple Inc.

import math
import time

import torch
import torch.nn as nn
import torch.mps


def sync_if_needed(x):
    if x.device != torch.device("cpu"):
        torch.mps.synchronize()


class RoPE(nn.Module):
    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
        else:
            rx = torch.cat([rx1, rx2], dim=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)

        return rx

    def forward(self, x, offset: int = 0):
        shape = x.shape
        x = x.view(-1, shape[-2], shape[-1])
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, device=x.device, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.view(*shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        device="cpu",
        dtype=torch.float32,
    ):
        D = D // 2
        positions = torch.arange(offset, N, dtype=dtype, device=device)
        freqs = torch.exp(
            -torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
        )
        theta = positions.view(-1, 1) * freqs.view(1, -1)
        costheta = torch.cos(theta)
        sintheta = torch.sin(theta)

        return costheta, sintheta


class RMSNorm(nn.Module):
    def __init__(self, dims: int, epsilon: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones((dims,)))
        self.epsilon = epsilon

    def forward(self, x):
        n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
        return self.gamma * x * n


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def forward(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = torch.cat([key_cache, keys], dim=2)
            values = torch.cat([value_cache, values], dim=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = torch.softmax(scores, dim=-1)
        values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = RMSNorm(dims)
        self.norm2 = RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def forward(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = torch.nn.functional.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


@torch.no_grad()
def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)
    end = time.time()
    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    device = torch.device("mps")
    dtype = torch.float16

    layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
    x = torch.randn(1, 1, D).to(device).to(dtype)
    cache = [
        torch.randn(1, H, C, D // H).to(device).to(dtype),
        torch.randn(1, H, C, D // H).to(device).to(dtype),
    ]

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
benchmarks/python/rope_bench.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(4096)

    # vec
    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_mat, x)


if __name__ == "__main__":
    time_rope()
benchmarks/python/scatter_bench.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Copyright © 2023-2024 Apple Inc.

import argparse

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
    def scatter(dst, x, idx):
        dst[idx] = x
        mx.eval(dst)

    idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)
    dst = mx.random.normal(dst_shape).astype(mx.float32)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
    def gather(dst, x, idx, device):
        dst[idx] = x
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)
    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
    idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
    x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]

    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
@@ -1,8 +1,8 @@
 # Copyright © 2023 Apple Inc.

 import argparse
-import mlx.core as mx
+
+import mlx.core as mx
 from time_utils import time_fn


@@ -44,6 +44,13 @@ def time_matmul():
     time_fn(mx.matmul, a, b)


+def time_maximum():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    b = mx.random.uniform(shape=(32, 1024, 1024))
+    mx.eval(a, b)
+    time_fn(mx.maximum, a, b)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)

@@ -101,6 +108,7 @@ if __name__ == "__main__":

     time_add()
     time_matmul()
+    time_maximum()
     time_exp()
     time_negative()
     time_logsumexp()
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import time

@@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):

     msec = 1e3 * (toc - tic) / num_iters
     print(f"{msec:.5f} msec")
+
+
+def measure_runtime(fn, **kwargs):
+    # Warmup
+    for _ in range(5):
+        fn(**kwargs)
+
+    tic = time.time()
+    iters = 100
+    for _ in range(iters):
+        fn(**kwargs)
+    return (time.time() - tic) * 1000 / iters
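The new `measure_runtime` helper warms up for five calls and then averages 100 timed calls, returning milliseconds. A minimal usage sketch with a throwaway function (not one of the benchmarks; run from the benchmarks/python directory so the import resolves):

    import mlx.core as mx
    from time_utils import measure_runtime

    def square(x):
        mx.eval(x * x)  # force the lazy op to run inside the timed call

    a = mx.random.normal((1024, 1024))
    print(f"{measure_runtime(square, x=a):.3f} ms")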
@@ -12,7 +12,7 @@ include(CMakeParseArguments)
 # OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
 # SOURCES: List of source files
 # INCLUDE_DIRS: List of include dirs
-# DEPS: List of depedency files (like headers)
+# DEPS: List of dependency files (like headers)
 #
 macro(mlx_build_metallib)
   # Parse args

@@ -32,7 +32,7 @@ macro(mlx_build_metallib)
   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)

-  # Prepare metllib build command
+  # Prepare metallib build command
   add_custom_command(
     OUTPUT ${MTLLIB_BUILD_TARGET}
     COMMAND xcrun -sdk macosx metal
docs/.gitignore (vendored, 2 lines added)
@@ -1 +1,3 @@
 src/python/_autosummary*/
+src/python/nn/_autosummary*/
+src/python/optimizers/_autosummary*/

@@ -26,7 +26,7 @@ python -m http.server <port>

 and point your browser to `http://localhost:<port>`.

-### Push to Github Pages
+### Push to GitHub Pages

 Check-out the `gh-pages` branch (`git switch gh-pages`) and build
 the docs. Then force add the `build/html` directory:
docs/src/_templates/module-base-class.rst (new file, 33 lines)
@@ -0,0 +1,33 @@
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. add toctree option to make autodoc generate the pages

.. autoclass:: {{ objname }}

  {% block attributes %}
  {% if attributes %}
  .. rubric:: Attributes

  .. autosummary::
    :toctree: .
  {% for item in attributes %}
    ~{{ fullname }}.{{ item }}
  {%- endfor %}
  {% endif %}
  {% endblock %}

  {% block methods %}
  {% if methods %}
  .. rubric:: Methods

  .. autosummary::
    :toctree: .
  {% for item in methods %}
    {%- if item not in inherited_members and item != '__init__' %}
    ~{{ fullname }}.{{ item }}
    {%- endif -%}
  {%- endfor %}
  {% endif %}
  {% endblock %}

@@ -1,19 +0,0 @@ (entire file removed)
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}

{#{% block methods %}

{% if methods %}
.. rubric:: {{ _('Methods') }}

.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}
@@ -5,13 +5,15 @@
 import os
 import subprocess

+import mlx.core as mx
+
 # -- Project information -----------------------------------------------------

 project = "MLX"
 copyright = "2023, MLX Contributors"
 author = "MLX Contributors"
-version = "0.0.4"
-release = "0.0.4"
+version = ".".join(mx.__version__.split(".")[:3])
+release = version

 # -- General configuration ---------------------------------------------------

@@ -24,6 +26,7 @@ extensions = [

 python_use_unqualified_type_names = True
 autosummary_generate = True
+autosummary_filename_map = {"mlx.core.Stream": "stream_class"}

 intersphinx_mapping = {
     "https://docs.python.org/3": None,
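The docs version is now derived from the installed package rather than hard-coded, with the slice keeping only the first three version components. An illustrative sketch (the version string is hypothetical, standing in for mx.__version__):

    full = "0.1.0.dev20240101"              # hypothetical mx.__version__
    print(".".join(full.split(".")[:3]))    # -> 0.1.0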
@@ -15,7 +15,7 @@ Introducing the Example
-----------------------

 Let's say that you would like an operation that takes in two arrays,
-``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
+``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
 respectively, and then adds them together to get the result
 ``z = alpha * x + beta * y``. Well, you can very easily do that by just
 writing out a function as follows:

@@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the
 You would really like the part of your applications that does this operation
 on the CPU to be very fast - so you decide that you want it to rely on the
 ``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
-our assumptions on to you, let's also assume that you want to learn how add
+our assumptions on to you, let's also assume that you want to learn how to add
 your own implementation for the gradients of your new operation while going
 over the ins-and-outs of the MLX framework.

@@ -69,7 +69,7 @@ C++ API:
 .. code-block:: C++

     /**
-    * Scale and sum two vectors elementwise
+    * Scale and sum two vectors element-wise
     * z = alpha * x + beta * y
     *
     * Follow numpy style broadcasting between x and y

@@ -150,7 +150,7 @@ back and go to our example to give ourselves a more concrete image.
         const std::vector<int>& argnums) override;

     /**
-     * The primitive must know how to vectorize itself accross
+     * The primitive must know how to vectorize itself across
      * the given axes. The output is a pair containing the array
      * representing the vectorized computation and the axis which
      * corresponds to the output vectorized dimension.

@@ -230,7 +230,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.

 This operation now handles the following:

-#. Upcast inputs and resolve the the output data type.
+#. Upcast inputs and resolve the output data type.
 #. Broadcast the inputs and resolve the output shape.
 #. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
 #. Construct the output :class:`array` using the primitive and the inputs.

@@ -284,14 +284,14 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
       T alpha = static_cast<T>(alpha_);
       T beta = static_cast<T>(beta_);

-      // Do the elementwise operation for each output
+      // Do the element-wise operation for each output
       for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
         // Map linear indices to offsets in x and y
         auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
         auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());

         // We allocate the output to be contiguous and regularly strided
-        // (defaults to row major) and hence it doesn't need additonal mapping
+        // (defaults to row major) and hence it doesn't need additional mapping
         out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
       }
     }

@@ -305,7 +305,7 @@ if we encounter an unexpected type.

     /** Fall back implementation for evaluation on CPU */
     void Axpby::eval(const std::vector<array>& inputs, array& out) {
-      // Check the inputs (registered in the op while contructing the out array)
+      // Check the inputs (registered in the op while constructing the out array)
       assert(inputs.size() == 2);
       auto& x = inputs[0];
       auto& y = inputs[1];

@@ -485,7 +485,7 @@ each data type.

     instantiate_axpby(float32, float);
     instantiate_axpby(float16, half);
-    instantiate_axpby(bflot16, bfloat16_t);
+    instantiate_axpby(bfloat16, bfloat16_t);
     instantiate_axpby(complex64, complex64_t);

 This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we

@@ -537,7 +537,7 @@ below.
       compute_encoder->setComputePipelineState(kernel);

       // Kernel parameters are registered with buffer indices corresponding to
-      // those in the kernel decelaration at axpby.metal
+      // those in the kernel declaration at axpby.metal
       int ndim = out.ndim();
       size_t nelem = out.size();

@@ -568,7 +568,7 @@ below.
       // Fix the 3D size of the launch grid (in terms of threads)
       MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

-      // Launch the grid with the given number of threads divded among
+      // Launch the grid with the given number of threads divided among
       // the given threadgroups
       compute_encoder->dispatchThreads(grid_dims, group_dims);
     }

@@ -581,7 +581,7 @@ to give us the active metal compute command encoder instead of building a
 new one and calling :meth:`compute_encoder->end_encoding` at the end.
 MLX keeps adding kernels (compute pipelines) to the active command encoder
 until some specified limit is hit or the compute encoder needs to be flushed
-for synchronization. MLX also handles enqueuing and commiting the associated
+for synchronization. MLX also handles enqueuing and committing the associated
 command buffers as needed. We suggest taking a deeper dive into
 :class:`metal::Device` if you would like to study this routine further.

@@ -601,8 +601,8 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
         const std::vector<array>& tangents,
         const std::vector<int>& argnums) {
       // Forward mode diff that pushes along the tangents
-      // The jvp transform on the the primitive can built with ops
-      // that are scheduled on the same stream as the primtive
+      // The jvp transform on the primitive can built with ops
+      // that are scheduled on the same stream as the primitive

       // If argnums = {0}, we only push along x in which case the
       // jvp is just the tangent scaled by alpha

@@ -642,7 +642,7 @@ own :class:`Primitive`.

 .. code-block:: C++

-    /** Vectorize primitve along given axis */
+    /** Vectorize primitive along given axis */
     std::pair<array, int> Axpby::vmap(
         const std::vector<array>& inputs,
         const std::vector<int>& axes) {

@@ -666,7 +666,7 @@ Let's look at the overall directory structure first.
 | └── setup.py

 * ``extensions/axpby/`` defines the C++ extension library
-* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
+* ``extensions/mlx_sample_extensions`` sets out the structure for the
   associated python package
 * ``extensions/bindings.cpp`` provides python bindings for our operation
 * ``extensions/CMakeLists.txt`` holds CMake rules to build the library and

@@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
 Binding to Python
 ^^^^^^^^^^^^^^^^^^

-We use PyBind11_ to build a Python API for the C++ library. Since bindings
-for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
-are already provided, adding our :meth:`axpby` becomes very simple!
+We use PyBind11_ to build a Python API for the C++ library. Since bindings for
+components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
+already provided, adding our :meth:`axpby` is simple!

 .. code-block:: C++

@@ -697,7 +697,7 @@ are already provided, adding our :meth:`axpby` becomes very simple!
           py::kw_only(),
           "stream"_a = py::none(),
           R"pbdoc(
-        Scale and sum two vectors elementwise
+        Scale and sum two vectors element-wise
         ``z = alpha * x + beta * y``

         Follows numpy style broadcasting between ``x`` and ``y``

@@ -840,7 +840,7 @@ This will result in a directory structure as follows:
 | ...

 When you try to install using the command ``python -m pip install .``
-(in ``extensions/``), the package will be installed with the same strucutre as
+(in ``extensions/``), the package will be installed with the same structure as
 ``extensions/mlx_sample_extensions`` and the C++ and metal library will be
 copied along with the python binding since they are specified as ``package_data``.

@@ -927,18 +927,18 @@ Results:

 We see some modest improvements right away!

-This operation is now good to be used to build other operations,
-in :class:`mlx.nn.Module` calls, and also as a part of graph
-transformations such as :meth:`grad` and :meth:`simplify`!
+This operation is now good to be used to build other operations, in
+:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
+:meth:`grad`!

 Scripts
 -------

 .. admonition:: Download the code

-   The full example code is available in `mlx-examples <code>`_.
+   The full example code is available in `mlx <code>`_.

-.. code: `TODO_LINK/extensions`_
+.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_

 .. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
 .. _Metal: https://developer.apple.com/documentation/metal?language=objc
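For orientation, the operation this extensions tutorial builds up is just a scaled elementwise sum; a plain MLX version of it (a sketch, not the tutorial's exact listing) looks like:

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        # z = alpha * x + beta * y, with numpy-style broadcasting between x and y
        return alpha * x + beta * y

    z = simple_axpby(mx.ones((3,)), mx.arange(3), alpha=2.0, beta=0.5)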
@@ -371,7 +371,7 @@ Scripts

 The full example code is available in `mlx-examples`_.

-.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llama
+.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama

 .. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
    Roformer: Enhanced transformer with rotary position embedding. arXiv
@@ -61,7 +61,10 @@ set:
def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)

-Next, setup the problem parameters and load the data:
+Next, set up the problem parameters and load the data. To load the data, you need our
+`mnist data loader
+<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
+we will import as `mnist`.

.. code-block:: python
@@ -19,7 +19,7 @@ The main differences between MLX and NumPy are:

The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
-`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
+`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
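
As a small, hedged illustration of the unified memory model described above,
the same arrays can be handed to operations on either device without copies
(``mx.cpu`` and ``mx.gpu`` are the built-in devices):

.. code-block:: python

    import mlx.core as mx

    a = mx.random.normal((256,))
    b = mx.random.normal((256,))

    # Both calls read the same buffers; no host/device copies are made
    mx.add(a, b, stream=mx.cpu)
    mx.add(a, b, stream=mx.gpu)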
@@ -35,8 +35,15 @@ are the CPU and GPU.
   :caption: Usage
   :maxdepth: 1

-   quick_start
-   using_streams
+   usage/quick_start
+   usage/lazy_evaluation
+   usage/unified_memory
+   usage/indexing
+   usage/saving_and_loading
+   usage/function_transforms
+   usage/compile
+   usage/numpy
+   usage/using_streams

.. toctree::
   :caption: Examples
@@ -56,6 +63,7 @@ are the CPU and GPU.
   python/random
   python/transforms
   python/fft
+  python/linalg
   python/nn
   python/optimizers
   python/tree_utils
@@ -1,8 +1,8 @@
Build and Install
=================

-Install from PyPI
------------------
+Python Installation
+-------------------

MLX is available on PyPI. All you have to do to use MLX with your own Apple
silicon computer is
@@ -11,9 +11,40 @@ silicon computer is
|
|||||||
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
|
|
||||||
|
To install from PyPI you must meet the following requirements:
|
||||||
|
|
||||||
|
- Using an M series chip (Apple silicon)
|
||||||
|
- Using a native Python >= 3.8
|
||||||
|
- macOS >= 13.3
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
MLX is only available on devices running MacOS >= 13.3
|
MLX is only available on devices running macOS >= 13.3
|
||||||
It is highly recommended to use MacOS 14 (Sonoma)
|
It is highly recommended to use macOS 14 (Sonoma)
|
||||||
|
|
||||||
|
|
||||||
|
MLX is also available on conda-forge. To install MLX with conda do:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
conda install conda-forge::mlx
|
||||||
|
|
||||||
|
|
||||||
|
Troubleshooting
|
||||||
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
*My OS and Python versions are in the required range but pip still does not find
|
||||||
|
a matching distribution.*
|
||||||
|
|
||||||
|
You are probably using a non-native Python. The output of
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
python -c "import platform; print(platform.processor())"
|
||||||
|
|
||||||
|
should be ``arm``. If it is ``i386`` (and you have M series machine) then you
|
||||||
|
are using a non-native Python. Switch your Python to a native Python. A good
|
||||||
|
way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.
|
||||||
|
|
||||||
|
|
||||||
Build from source
|
Build from source
|
||||||
-----------------
|
-----------------
|
||||||
@@ -23,8 +54,11 @@ Build Requirements
|
|||||||
|
|
||||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
||||||
- Xcode >= 14.3 (Xcode >= 15.0 for MacOS 14 and above)
|
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
|
||||||
|
the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.
|
||||||
|
|
||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
@@ -60,9 +94,17 @@ For developing use an editable install:
|
|||||||
To make sure the install is working run the tests with:
|
To make sure the install is working run the tests with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install ".[testing]"
|
pip install ".[testing]"
|
||||||
python -m unittest discover python/tests
|
python -m unittest discover python/tests
|
||||||
|
|
||||||
|
Optional: Install stubs to enable auto completions and type checking from your IDE:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install ".[dev]"
|
||||||
|
python setup.py generate_stubs
|
||||||
|
|
||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
@@ -129,8 +171,64 @@ should point to the path to the built metal library.
|
|||||||
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
|
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
|
||||||
|
|
||||||
Further, you can use the following command to find out which
|
Further, you can use the following command to find out which
|
||||||
MacOS SDK will be used
|
macOS SDK will be used
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
Troubleshooting
|
||||||
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
|
Metal not found
|
||||||
|
~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
You see the following error when you try to build:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
error: unable to find utility "metal", not a developer tool or in PATH
|
||||||
|
|
||||||
|
To fix this, first make sure you have Xcode installed:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
xcode-select --install
|
||||||
|
|
||||||
|
Then set the active developer directory:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
|
||||||
|
|
||||||
|
x86 Shell
|
||||||
|
~~~~~~~~~
|
||||||
|
|
||||||
|
.. _build shell:
|
||||||
|
|
||||||
|
If the output of ``uname -p`` is ``x86``, then your shell is running as x86 via
|
||||||
|
Rosetta instead of natively.
|
||||||
|
|
||||||
|
To fix this, find the application in Finder (``/Applications`` for iTerm,
|
||||||
|
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
|
||||||
|
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
|
||||||
|
terminal.
|
||||||
|
|
||||||
|
Verify the terminal is now running natively with the following command:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
$ uname -p
|
||||||
|
arm
|
||||||
|
|
||||||
|
Also check that cmake is using the correct architecture:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
$ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
|
||||||
|
CMAKE_HOST_SYSTEM_PROCESSOR "arm64"
|
||||||
|
|
||||||
|
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
|
||||||
|
but the build errors out with "Building for x86_64 on macOS is not supported."
|
||||||
|
wipe your build cache with ``rm -rf build/`` and try again.
|
||||||
|
@@ -34,6 +34,7 @@ Array
   array.prod
   array.reciprocal
   array.reshape
+  array.round
   array.rsqrt
   array.sin
   array.split
|
@@ -29,9 +29,9 @@ The default floating point type is ``float32`` and the default integer type is
   * - ``uint32``
     - 4
     - 32-bit unsigned integer
-  * - ``uint32``
+  * - ``uint64``
     - 8
-    - 32-bit unsigned integer
+    - 64-bit unsigned integer
   * - ``int8``
     - 1
     - 8-bit signed integer
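
For example, a dtype from the table can be requested explicitly when
constructing an array:

.. code-block:: python

    import mlx.core as mx

    x = mx.array([1, 2, 3])                    # default integer type: int32
    y = mx.array([1.0, 2.0, 3.0])              # default float type: float32
    z = mx.array([1, 2, 3], dtype=mx.uint64)   # 8 bytes per element

    print(x.dtype, y.dtype, z.dtype)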
||||||
|
@@ -9,9 +9,10 @@ Devices and Streams
   :toctree: _autosummary

   Device
+  Stream
   default_device
   set_default_device
-  Stream
   default_stream
   new_stream
   set_default_stream
+  stream
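
A short sketch of how the listed stream APIs fit together; the ``stream``
entry added above is assumed to be a context manager that scopes a default
stream:

.. code-block:: python

    import mlx.core as mx

    s = mx.new_stream(mx.gpu)            # a new stream on the GPU device
    a = mx.random.uniform(shape=(4, 4))

    # Run an op on an explicit stream ...
    b = mx.exp(a, stream=s)

    # ... or scope a default stream with the (assumed) context manager
    with mx.stream(s):
        c = mx.square(a)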
||||||
|
12 docs/src/python/linalg.rst (new file)
@@ -0,0 +1,12 @@
.. _linalg:

Linear Algebra
==============

.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

   norm
   qr
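
A minimal sketch of the two listed routines; note the assumption that, as in
early MLX releases, :func:`qr` runs on the CPU stream:

.. code-block:: python

    import mlx.core as mx

    v = mx.array([3.0, 4.0])
    print(mx.linalg.norm(v))  # array(5, dtype=float32)

    A = mx.array([[2.0, 1.0], [1.0, 2.0]])
    Q, R = mx.linalg.qr(A, stream=mx.cpu)  # assumption: CPU stream required here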
@@ -64,7 +64,6 @@ Quick Start with Neural Networks
|
|||||||
# gradient with respect to `mlp.trainable_parameters()`
|
# gradient with respect to `mlp.trainable_parameters()`
|
||||||
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
|
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
|
||||||
|
|
||||||
|
|
||||||
.. _module_class:
|
.. _module_class:
|
||||||
|
|
||||||
The Module Class
|
The Module Class
|
||||||
@@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other
|
|||||||
:meth:`Module.parameters` can be used to extract a nested dictionary with all
|
:meth:`Module.parameters` can be used to extract a nested dictionary with all
|
||||||
the parameters of a module and its submodules.
|
the parameters of a module and its submodules.
|
||||||
|
|
||||||
-A :class:`Module` can also keep track of "frozen" parameters.
-:meth:`Module.trainable_parameters` returns only the subset of
-:meth:`Module.parameters` that is not frozen. When using
-:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
-trainable parameters.
+A :class:`Module` can also keep track of "frozen" parameters. See the
+:meth:`Module.freeze` method for more details. When using
+:meth:`mlx.nn.value_and_grad`, the gradients returned will be with respect to
+these trainable parameters.
||||||
|
|
||||||
Updating the parameters
|
|
||||||
|
Updating the Parameters
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
MLX modules allow accessing and updating individual parameters. However, most
|
MLX modules allow accessing and updating individual parameters. However, most
|
||||||
times we need to update large subsets of a module's parameters. This action is
|
times we need to update large subsets of a module's parameters. This action is
|
||||||
performed by :meth:`Module.update`.
|
performed by :meth:`Module.update`.
|
||||||
|
|
||||||
Value and grad
|
|
||||||
|
Inspecting Modules
|
||||||
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The simplest way to see the model architecture is to print it. Following along with
|
||||||
|
the above example, you can print the ``MLP`` with:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
print(mlp)
|
||||||
|
|
||||||
|
This will display:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
MLP(
|
||||||
|
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
|
||||||
|
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
|
||||||
|
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
To get more detailed information on the arrays in a :class:`Module` you can use
|
||||||
|
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
|
||||||
|
all the parameters in a :class:`Module` do:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from mlx.utils import tree_map
|
||||||
|
shapes = tree_map(lambda p: p.shape, mlp.parameters())
|
||||||
|
|
||||||
|
As another example, you can count the number of parameters in a :class:`Module`
|
||||||
|
with:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
|
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
|
||||||
|
|
||||||
|
|
||||||
|
Value and Grad
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
Using a :class:`Module` does not preclude using MLX's high order function
|
Using a :class:`Module` does not preclude using MLX's high order function
|
||||||
@@ -137,36 +174,10 @@ In detail:

   value_and_grad

-Neural Network Layers
----------------------
-
-.. autosummary::
-   :toctree: _autosummary
-   :template: nn-module-template.rst
-
-   Embedding
-   ReLU
-   GELU
-   SiLU
-   Linear
-   Conv1d
-   Conv2d
-   LayerNorm
-   RMSNorm
-   GroupNorm
-   RoPE
-   MultiHeadAttention
-   Sequential
-
-Layers without parameters (e.g. activation functions) are also provided as
-simple functions.
-
-.. autosummary::
-   :toctree: _autosummary_functions
-   :template: nn-module-template.rst
-
-   gelu
-   gelu_approx
-   gelu_fast_approx
-   relu
-   silu
+.. toctree::
+
+   nn/module
+   nn/layers
+   nn/functions
+   nn/losses
+   nn/init
||||||
|
24
docs/src/python/nn/functions.rst
Normal file
24
docs/src/python/nn/functions.rst
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
.. _nn_functions:
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.nn
|
||||||
|
|
||||||
|
Functions
|
||||||
|
---------
|
||||||
|
|
||||||
|
Layers without parameters (e.g. activation functions) are also provided as
|
||||||
|
simple functions.
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary_functions
|
||||||
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
|
gelu
|
||||||
|
gelu_approx
|
||||||
|
gelu_fast_approx
|
||||||
|
mish
|
||||||
|
prelu
|
||||||
|
relu
|
||||||
|
selu
|
||||||
|
softshrink
|
||||||
|
silu
|
||||||
|
step
|
45
docs/src/python/nn/init.rst
Normal file
45
docs/src/python/nn/init.rst
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
.. _init:
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.nn.init
|
||||||
|
|
||||||
|
Initializers
|
||||||
|
------------
|
||||||
|
|
||||||
|
The ``mlx.nn.init`` package contains commonly used initializers for neural
|
||||||
|
network parameters. Initializers return a function which can be applied to any
|
||||||
|
input :obj:`mlx.core.array` to produce an initialized output.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
init_fn = nn.init.uniform()
|
||||||
|
|
||||||
|
# Produces a [2, 2] uniform matrix
|
||||||
|
param = init_fn(mx.zeros((2, 2)))
|
||||||
|
|
||||||
|
To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a uniform
|
||||||
|
distribution, you can do:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
import mlx.nn as nn
|
||||||
|
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
|
||||||
|
init_fn = nn.init.uniform(low=-0.1, high=0.1)
|
||||||
|
model.apply(init_fn)
|
||||||
|
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
constant
|
||||||
|
normal
|
||||||
|
uniform
|
||||||
|
identity
|
||||||
|
glorot_normal
|
||||||
|
glorot_uniform
|
||||||
|
he_normal
|
||||||
|
he_uniform
|
42
docs/src/python/nn/layers.rst
Normal file
42
docs/src/python/nn/layers.rst
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
.. _layers:
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.nn
|
||||||
|
|
||||||
|
Layers
|
||||||
|
------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
|
ALiBi
|
||||||
|
AvgPool1d
|
||||||
|
AvgPool2d
|
||||||
|
BatchNorm
|
||||||
|
Conv1d
|
||||||
|
Conv2d
|
||||||
|
Dropout
|
||||||
|
Dropout2d
|
||||||
|
Dropout3d
|
||||||
|
Embedding
|
||||||
|
GELU
|
||||||
|
GroupNorm
|
||||||
|
InstanceNorm
|
||||||
|
LayerNorm
|
||||||
|
Linear
|
||||||
|
MaxPool1d
|
||||||
|
MaxPool2d
|
||||||
|
Mish
|
||||||
|
MultiHeadAttention
|
||||||
|
PReLU
|
||||||
|
QuantizedLinear
|
||||||
|
RMSNorm
|
||||||
|
ReLU
|
||||||
|
RoPE
|
||||||
|
SELU
|
||||||
|
Sequential
|
||||||
|
SiLU
|
||||||
|
SinusoidalPositionalEncoding
|
||||||
|
Softshrink
|
||||||
|
Step
|
||||||
|
Transformer
|
25
docs/src/python/nn/losses.rst
Normal file
25
docs/src/python/nn/losses.rst
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
.. _losses:
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.nn.losses
|
||||||
|
|
||||||
|
Loss Functions
|
||||||
|
--------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary_functions
|
||||||
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
|
binary_cross_entropy
|
||||||
|
cosine_similarity_loss
|
||||||
|
cross_entropy
|
||||||
|
gaussian_nll_loss
|
||||||
|
hinge_loss
|
||||||
|
huber_loss
|
||||||
|
kl_div_loss
|
||||||
|
l1_loss
|
||||||
|
log_cosh_loss
|
||||||
|
margin_ranking_loss
|
||||||
|
mse_loss
|
||||||
|
nll_loss
|
||||||
|
smooth_l1_loss
|
||||||
|
triplet_loss
|
@@ -1,7 +1,37 @@
|
|||||||
mlx.nn.Module
|
Module
|
||||||
=============
|
======
|
||||||
|
|
||||||
.. currentmodule:: mlx.nn
|
.. currentmodule:: mlx.nn
|
||||||
|
|
||||||
.. autoclass:: Module
|
.. autoclass:: Module
|
||||||
:members:
|
|
||||||
|
.. rubric:: Attributes
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Module.training
|
||||||
|
Module.state
|
||||||
|
|
||||||
|
.. rubric:: Methods
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Module.apply
|
||||||
|
Module.apply_to_modules
|
||||||
|
Module.children
|
||||||
|
Module.eval
|
||||||
|
Module.filter_and_map
|
||||||
|
Module.freeze
|
||||||
|
Module.leaf_modules
|
||||||
|
Module.load_weights
|
||||||
|
Module.modules
|
||||||
|
Module.named_modules
|
||||||
|
Module.parameters
|
||||||
|
Module.save_weights
|
||||||
|
Module.train
|
||||||
|
Module.trainable_parameters
|
||||||
|
Module.unfreeze
|
||||||
|
Module.update
|
||||||
|
Module.update_modules
|
||||||
|
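
To make the method listing above concrete, a small sketch of freezing and
unfreezing a module and inspecting its trainable parameters:

.. code-block:: python

    import mlx.nn as nn

    model = nn.Linear(4, 4)

    # Freeze everything: trainable_parameters() is now empty
    model.freeze()
    print(model.trainable_parameters())

    # Unfreeze to make the parameters trainable again
    model.unfreeze()
    print(model.trainable_parameters())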
@@ -26,23 +26,40 @@ Operations
|
|||||||
argsort
|
argsort
|
||||||
array_equal
|
array_equal
|
||||||
broadcast_to
|
broadcast_to
|
||||||
|
ceil
|
||||||
|
clip
|
||||||
concatenate
|
concatenate
|
||||||
convolve
|
convolve
|
||||||
conv1d
|
conv1d
|
||||||
conv2d
|
conv2d
|
||||||
cos
|
cos
|
||||||
cosh
|
cosh
|
||||||
|
dequantize
|
||||||
|
diag
|
||||||
|
diagonal
|
||||||
divide
|
divide
|
||||||
|
divmod
|
||||||
equal
|
equal
|
||||||
erf
|
erf
|
||||||
erfinv
|
erfinv
|
||||||
exp
|
exp
|
||||||
expand_dims
|
expand_dims
|
||||||
|
eye
|
||||||
|
flatten
|
||||||
|
floor
|
||||||
|
floor_divide
|
||||||
full
|
full
|
||||||
greater
|
greater
|
||||||
greater_equal
|
greater_equal
|
||||||
|
identity
|
||||||
|
inner
|
||||||
|
isnan
|
||||||
|
isposinf
|
||||||
|
isneginf
|
||||||
|
isinf
|
||||||
less
|
less
|
||||||
less_equal
|
less_equal
|
||||||
|
linspace
|
||||||
load
|
load
|
||||||
log
|
log
|
||||||
log2
|
log2
|
||||||
@@ -50,6 +67,8 @@ Operations
|
|||||||
log1p
|
log1p
|
||||||
logaddexp
|
logaddexp
|
||||||
logical_not
|
logical_not
|
||||||
|
logical_and
|
||||||
|
logical_or
|
||||||
logsumexp
|
logsumexp
|
||||||
matmul
|
matmul
|
||||||
max
|
max
|
||||||
@@ -57,19 +76,27 @@ Operations
|
|||||||
mean
|
mean
|
||||||
min
|
min
|
||||||
minimum
|
minimum
|
||||||
|
moveaxis
|
||||||
multiply
|
multiply
|
||||||
negative
|
negative
|
||||||
ones
|
ones
|
||||||
ones_like
|
ones_like
|
||||||
|
outer
|
||||||
partition
|
partition
|
||||||
pad
|
pad
|
||||||
prod
|
prod
|
||||||
|
quantize
|
||||||
|
quantized_matmul
|
||||||
reciprocal
|
reciprocal
|
||||||
|
repeat
|
||||||
reshape
|
reshape
|
||||||
|
round
|
||||||
rsqrt
|
rsqrt
|
||||||
save
|
save
|
||||||
savez
|
savez
|
||||||
savez_compressed
|
savez_compressed
|
||||||
|
save_gguf
|
||||||
|
save_safetensors
|
||||||
sigmoid
|
sigmoid
|
||||||
sign
|
sign
|
||||||
sin
|
sin
|
||||||
@@ -80,14 +107,20 @@ Operations
|
|||||||
sqrt
|
sqrt
|
||||||
square
|
square
|
||||||
squeeze
|
squeeze
|
||||||
|
stack
|
||||||
stop_gradient
|
stop_gradient
|
||||||
subtract
|
subtract
|
||||||
sum
|
sum
|
||||||
|
swapaxes
|
||||||
take
|
take
|
||||||
take_along_axis
|
take_along_axis
|
||||||
tan
|
tan
|
||||||
tanh
|
tanh
|
||||||
|
tensordot
|
||||||
transpose
|
transpose
|
||||||
|
tri
|
||||||
|
tril
|
||||||
|
triu
|
||||||
var
|
var
|
||||||
where
|
where
|
||||||
zeros
|
zeros
|
||||||
|
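
A few of the operations added in the listings above, shown together in a
short, hedged sketch:

.. code-block:: python

    import mlx.core as mx

    x = mx.linspace(0, 1, num=5)
    y = mx.clip(x, 0.2, 0.8)
    I = mx.eye(3)
    q, r = mx.divmod(mx.array([7, 9]), mx.array([2, 4]))
    print(y, I, q, r)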
@@ -29,13 +29,8 @@ model's parameters and the **optimizer state**.
    # Compute the new parameters but also the optimizer state.
    mx.eval(model.parameters(), optimizer.state)

-.. currentmodule:: mlx.optimizers
-
-.. autosummary::
-   :toctree: _autosummary
-   :template: optimizers-template.rst
-
-   OptimizerState
-   Optimizer
-   SGD
-   Adam
+.. toctree::
+
+   optimizers/optimizer
+   optimizers/common_optimizers
+   optimizers/schedulers
|
20
docs/src/python/optimizers/common_optimizers.rst
Normal file
20
docs/src/python/optimizers/common_optimizers.rst
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
.. _common_optimizers:
|
||||||
|
|
||||||
|
Common Optimizers
|
||||||
|
=================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.optimizers
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
:template: optimizers-template.rst
|
||||||
|
|
||||||
|
SGD
|
||||||
|
RMSprop
|
||||||
|
Adagrad
|
||||||
|
Adafactor
|
||||||
|
AdaDelta
|
||||||
|
Adam
|
||||||
|
AdamW
|
||||||
|
Adamax
|
||||||
|
Lion
|
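
All of the listed optimizers share the same construction pattern; a brief
sketch:

.. code-block:: python

    import mlx.optimizers as optim

    sgd = optim.SGD(learning_rate=0.1, momentum=0.9)
    adamw = optim.AdamW(learning_rate=1e-3)
    lion = optim.Lion(learning_rate=1e-4)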
23
docs/src/python/optimizers/optimizer.rst
Normal file
23
docs/src/python/optimizers/optimizer.rst
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
Optimizer
|
||||||
|
=========
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.optimizers
|
||||||
|
|
||||||
|
.. autoclass:: Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
.. rubric:: Attributes
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Optimizer.state
|
||||||
|
|
||||||
|
.. rubric:: Methods
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Optimizer.apply_gradients
|
||||||
|
Optimizer.init
|
||||||
|
Optimizer.update
|
13
docs/src/python/optimizers/schedulers.rst
Normal file
13
docs/src/python/optimizers/schedulers.rst
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
.. _schedulers:
|
||||||
|
|
||||||
|
Schedulers
|
||||||
|
==========
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.optimizers
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
step_decay
|
||||||
|
exponential_decay
|
||||||
|
cosine_decay
|
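
A scheduler is a callable mapping the optimizer step to a learning rate. A
hedged sketch, under the assumption that an optimizer accepts a callable
``learning_rate``:

.. code-block:: python

    import mlx.optimizers as optim

    # Decay from 1e-2 towards 0 over 1000 steps
    lr_schedule = optim.cosine_decay(1e-2, 1000)

    # Assumption: the schedule can be passed wherever a learning rate is expected
    optimizer = optim.SGD(learning_rate=lr_schedule)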
@@ -33,13 +33,13 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
.. autosummary::
   :toctree: _autosummary

-   seed
-   key
-   split
   bernoulli
   categorical
   gumbel
+   key
   normal
   randint
-   uniform
+   seed
+   split
   truncated_normal
+   uniform
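
The reordered listing above includes the explicit PRNG key utilities; a small
sketch of using keys instead of the global seed:

.. code-block:: python

    import mlx.core as mx

    key = mx.random.key(0)          # an explicit PRNG key
    keys = mx.random.split(key)     # two new, independent keys by default

    a = mx.random.uniform(shape=(2, 2), key=keys[0])
    b = mx.random.normal(shape=(2, 2), key=keys[1])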
||||||
|
@@ -9,6 +9,9 @@ Transforms
   :toctree: _autosummary

   eval
+  compile
+  disable_compile
+  enable_compile
   grad
   value_and_grad
   jvp
||||||
|
430
docs/src/usage/compile.rst
Normal file
430
docs/src/usage/compile.rst
Normal file
@@ -0,0 +1,430 @@
|
|||||||
|
.. _compile:
|
||||||
|
|
||||||
|
Compilation
|
||||||
|
===========
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
MLX has a :func:`compile` function transformation which compiles computation
|
||||||
|
graphs. Function compilation results in smaller graphs by merging common work
|
||||||
|
and fusing certain operations. In many cases this can lead to big improvements
|
||||||
|
in run-time and memory use.
|
||||||
|
|
||||||
|
Getting started with :func:`compile` is simple, but there are some edge cases
|
||||||
|
that are good to be aware of for more complex graphs and advanced usage.
|
||||||
|
|
||||||
|
Basics of Compile
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
Let's start with a simple example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x, y):
|
||||||
|
return mx.exp(-x) + y
|
||||||
|
|
||||||
|
x = mx.array(1.0)
|
||||||
|
y = mx.array(2.0)
|
||||||
|
|
||||||
|
# Regular call, no compilation
|
||||||
|
# Prints: array(2.36788, dtype=float32)
|
||||||
|
print(fun(x, y))
|
||||||
|
|
||||||
|
# Compile the function
|
||||||
|
compiled_fun = mx.compile(fun)
|
||||||
|
|
||||||
|
# Prints: array(2.36788, dtype=float32)
|
||||||
|
print(compiled_fun(x, y))
|
||||||
|
|
||||||
|
The output of both the regular function and the compiled function is the same
|
||||||
|
up to numerical precision.
|
||||||
|
|
||||||
|
The first time you call a compiled function, MLX will build the compute
|
||||||
|
graph, optimize it, and generate and compile code. This can be relatively
|
||||||
|
slow. However, MLX will cache compiled functions, so calling a compiled
|
||||||
|
function multiple times will not initiate a new compilation. This means you
|
||||||
|
should typically compile functions that you plan to use more than once.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x, y):
|
||||||
|
return mx.exp(-x) + y
|
||||||
|
|
||||||
|
x = mx.array(1.0)
|
||||||
|
y = mx.array(2.0)
|
||||||
|
|
||||||
|
compiled_fun = mx.compile(fun)
|
||||||
|
|
||||||
|
# Compiled here
|
||||||
|
compiled_fun(x, y)
|
||||||
|
|
||||||
|
# Not compiled again
|
||||||
|
compiled_fun(x, y)
|
||||||
|
|
||||||
|
# Not compiled again
|
||||||
|
mx.compile(fun)(x, y)
|
||||||
|
|
||||||
|
There are some important cases to be aware of that can cause a function to
|
||||||
|
be recompiled:
|
||||||
|
|
||||||
|
* Changing the shape or number of dimensions
|
||||||
|
* Changing the type of any of the inputs
|
||||||
|
* Changing the number of inputs to the function
|
||||||
|
|
||||||
|
In certain cases only some of the compilation stack will be rerun (for
|
||||||
|
example when changing the shapes) and in other cases the full compilation
|
||||||
|
stack will be rerun (for example when changing the types). In general you
|
||||||
|
should avoid compiling functions too frequently.
|
||||||
|
|
||||||
|
Another idiom to watch out for is compiling functions which get created and
|
||||||
|
destroyed frequently. This can happen, for example, when compiling an anonymous
|
||||||
|
function in a loop:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
a = mx.array(1.0)
|
||||||
|
# Don't do this, compiles lambda at each iteration
|
||||||
|
for _ in range(5):
|
||||||
|
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
|
||||||
|
|
||||||
|
Example Speedup
|
||||||
|
---------------
|
||||||
|
|
||||||
|
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
|
||||||
|
Transformer-based models. The implementation involves several unary and binary
|
||||||
|
element-wise operations:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def gelu(x):
|
||||||
|
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||||
|
|
||||||
|
If you use this function with small arrays, it will be overhead bound. If you
|
||||||
|
use it with large arrays it will be memory bandwidth bound. However, all of
|
||||||
|
the operations in the ``gelu`` are fusible into a single kernel with
|
||||||
|
:func:`compile`. This can speedup both cases considerably.
|
||||||
|
|
||||||
|
Let's compare the runtime of the regular function versus the compiled
|
||||||
|
function. We'll use the following timing helper which does a warm up and
|
||||||
|
handles synchronization:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
def timeit(fun, x):
|
||||||
|
# warm up
|
||||||
|
for _ in range(10):
|
||||||
|
mx.eval(fun(x))
|
||||||
|
|
||||||
|
tic = time.perf_counter()
|
||||||
|
for _ in range(100):
|
||||||
|
mx.eval(fun(x))
|
||||||
|
toc = time.perf_counter()
|
||||||
|
tpi = 1e3 * (toc - tic) / 100
|
||||||
|
print(f"Time per iteration {tpi:.3f} (ms)")
|
||||||
|
|
||||||
|
|
||||||
|
Now make an array, and benchmark both functions:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(32, 1000, 4096))
|
||||||
|
timeit(nn.gelu, x)
|
||||||
|
timeit(mx.compile(nn.gelu), x)
|
||||||
|
|
||||||
|
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||||
|
five times faster.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
|
||||||
|
functions can still be helpful, but won't typically result in as large a
|
||||||
|
speedup as compiling operations that run on the GPU.
|
||||||
|
|
||||||
|
|
||||||
|
Debugging
|
||||||
|
---------
|
||||||
|
|
||||||
|
When a compiled function is first called, it is traced with placeholder
|
||||||
|
inputs. This means you can't evaluate arrays (for example to print their
|
||||||
|
contents) inside compiled functions.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(x):
|
||||||
|
z = -x
|
||||||
|
print(z) # Crash
|
||||||
|
return mx.exp(z)
|
||||||
|
|
||||||
|
fun(mx.array(5.0))
|
||||||
|
|
||||||
|
For debugging, inspecting arrays can be helpful. One way to do that is to
|
||||||
|
globally disable compilation using the :func:`disable_compile` function or
|
||||||
|
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
|
||||||
|
``fun`` is compiled:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(x):
|
||||||
|
z = -x
|
||||||
|
print(z) # Okay
|
||||||
|
return mx.exp(z)
|
||||||
|
|
||||||
|
mx.disable_compile()
|
||||||
|
fun(mx.array(5.0))
|
||||||
|
|
||||||
|
|
||||||
|
Pure Functions
|
||||||
|
--------------
|
||||||
|
|
||||||
|
Compiled functions are intended to be *pure*; that is they should not have side
|
||||||
|
effects. For example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
state = []
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(x, y):
|
||||||
|
z = x + y
|
||||||
|
state.append(z)
|
||||||
|
return mx.exp(z)
|
||||||
|
|
||||||
|
fun(mx.array(1.0), mx.array(2.0))
|
||||||
|
# Crash!
|
||||||
|
print(state)
|
||||||
|
|
||||||
|
After the first call of ``fun``, the ``state`` list will hold a placeholder
|
||||||
|
array. The placeholder does not have any data; it is only used to build the
|
||||||
|
computation graph. Printing such an array results in a crash.
|
||||||
|
|
||||||
|
You have two options to deal with this. The first option is to simply return
|
||||||
|
``state`` as an output:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
state = []
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(x, y):
|
||||||
|
z = x + y
|
||||||
|
state.append(z)
|
||||||
|
return mx.exp(z), state
|
||||||
|
|
||||||
|
_, state = fun(mx.array(1.0), mx.array(2.0))
|
||||||
|
# Prints [array(3, dtype=float32)]
|
||||||
|
print(state)
|
||||||
|
|
||||||
|
In some cases returning updated state can be pretty inconvenient. Hence,
|
||||||
|
:func:`compile` has a parameter to capture implicit outputs:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
state = []
|
||||||
|
|
||||||
|
# Tell compile to capture state as an output
|
||||||
|
@partial(mx.compile, outputs=state)
|
||||||
|
def fun(x, y):
|
||||||
|
z = x + y
|
||||||
|
state.append(z)
|
||||||
|
return mx.exp(z), state
|
||||||
|
|
||||||
|
fun(mx.array(1.0), mx.array(2.0))
|
||||||
|
# Prints [array(3, dtype=float32)]
|
||||||
|
print(state)
|
||||||
|
|
||||||
|
This is particularly useful for compiling a function which includes an update
|
||||||
|
to a container of arrays, as is commonly done when training the parameters of a
|
||||||
|
:class:`mlx.nn.Module`.
|
||||||
|
|
||||||
|
Compiled functions will also treat any inputs not in the parameter list as
|
||||||
|
constants. For example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
state = [mx.array(1.0)]
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(x):
|
||||||
|
return x + state[0]
|
||||||
|
|
||||||
|
# Prints array(2, dtype=float32)
|
||||||
|
print(fun(mx.array(1.0)))
|
||||||
|
|
||||||
|
# Update state
|
||||||
|
state[0] = mx.array(5.0)
|
||||||
|
|
||||||
|
# Still prints array(2, dtype=float32)
|
||||||
|
print(fun(mx.array(1.0)))
|
||||||
|
|
||||||
|
In order to have the change of state reflected in the outputs of ``fun`` you
|
||||||
|
again have two options. The first option is to simply pass ``state`` as input
|
||||||
|
to the function. In some cases this can be pretty inconvenient. Hence,
|
||||||
|
:func:`compile` also has a parameter to capture implicit inputs:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
state = [mx.array(1.0)]
|
||||||
|
|
||||||
|
# Tell compile to capture state as an input
|
||||||
|
@partial(mx.compile, inputs=state)
|
||||||
|
def fun(x):
|
||||||
|
return x + state[0]
|
||||||
|
|
||||||
|
# Prints array(2, dtype=float32)
|
||||||
|
print(fun(mx.array(1.0)))
|
||||||
|
|
||||||
|
# Update state
|
||||||
|
state[0] = mx.array(5.0)
|
||||||
|
|
||||||
|
# Prints array(6, dtype=float32)
|
||||||
|
print(fun(mx.array(1.0)))
|
||||||
|
|
||||||
|
|
||||||
|
Compiling Training Graphs
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
This section will step through how to use :func:`compile` with a simple example
|
||||||
|
of a common setup: training a model with :obj:`mlx.nn.Module` using an
|
||||||
|
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
|
||||||
|
full forward, backward, and update with :func:`compile`.
|
||||||
|
|
||||||
|
To start, here is the simple example without any compilation:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
|
||||||
|
# 4 examples with 10 features each
|
||||||
|
x = mx.random.uniform(shape=(4, 10))
|
||||||
|
|
||||||
|
# 0, 1 targets
|
||||||
|
y = mx.array([0, 1, 0, 1])
|
||||||
|
|
||||||
|
# Simple linear model
|
||||||
|
model = nn.Linear(10, 1)
|
||||||
|
|
||||||
|
# SGD with momentum
|
||||||
|
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
|
||||||
|
|
||||||
|
def loss_fn(model, x, y):
|
||||||
|
logits = model(x).squeeze()
|
||||||
|
return nn.losses.binary_cross_entropy(logits, y)
|
||||||
|
|
||||||
|
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||||
|
|
||||||
|
# Perform 10 steps of gradient descent
|
||||||
|
for it in range(10):
|
||||||
|
loss, grads = loss_and_grad_fn(model, x, y)
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
mx.eval(model.parameters(), optimizer.state)
|
||||||
|
|
||||||
|
To compile the update we can put it all in a function and compile it with the
|
||||||
|
appropriate input and output captures. Here's the same example but compiled:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
# 4 examples with 10 features each
|
||||||
|
x = mx.random.uniform(shape=(4, 10))
|
||||||
|
|
||||||
|
# 0, 1 targets
|
||||||
|
y = mx.array([0, 1, 0, 1])
|
||||||
|
|
||||||
|
# Simple linear model
|
||||||
|
model = nn.Linear(10, 1)
|
||||||
|
|
||||||
|
# SGD with momentum
|
||||||
|
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
|
||||||
|
|
||||||
|
def loss_fn(model, x, y):
|
||||||
|
logits = model(x).squeeze()
|
||||||
|
return nn.losses.binary_cross_entropy(logits, y)
|
||||||
|
|
||||||
|
# The state that will be captured as input and output
|
||||||
|
state = [model.state, optimizer.state]
|
||||||
|
|
||||||
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
|
def step(x, y):
|
||||||
|
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||||
|
loss, grads = loss_and_grad_fn(model, x, y)
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
# Perform 10 steps of gradient descent
|
||||||
|
for it in range(10):
|
||||||
|
loss = step(x, y)
|
||||||
|
# Evaluate the model and optimizer state
|
||||||
|
mx.eval(state)
|
||||||
|
print(loss)
|
||||||
|
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
If you are using a module which performs random sampling such as
|
||||||
|
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
|
||||||
|
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
|
||||||
|
optimizer.state, mx.random.state]``.
|
||||||
|
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
For more examples of compiling full training graphs checkout the `MLX
|
||||||
|
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
|
||||||
|
|
||||||
|
Transformations with Compile
|
||||||
|
----------------------------
|
||||||
|
|
||||||
|
In MLX function transformations are composable. You can apply any function
|
||||||
|
transformation to the output of any other function transformation. For more on
|
||||||
|
this, see the documentation on :ref:`function transforms
|
||||||
|
<function_transforms>`.
|
||||||
|
|
||||||
|
Compiling transformed functions works just as expected:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
grad_fn = mx.grad(mx.exp)
|
||||||
|
|
||||||
|
compiled_grad_fn = mx.compile(grad_fn)
|
||||||
|
|
||||||
|
# Prints: array(2.71828, dtype=float32)
|
||||||
|
print(grad_fn(mx.array(1.0)))
|
||||||
|
|
||||||
|
# Also prints: array(2.71828, dtype=float32)
|
||||||
|
print(compiled_grad_fn(mx.array(1.0)))
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
In order to compile as much as possible, a transformation of a compiled
|
||||||
|
function will not by default be compiled. To compile the transformed
|
||||||
|
function simply pass it through :func:`compile`.
|
||||||
|
|
||||||
|
You can also compile functions which themselves call compiled functions. A
|
||||||
|
good practice is to compile the outermost function to give :func:`compile`
|
||||||
|
the most opportunity to optimize the computation graph:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def inner(x):
|
||||||
|
return mx.exp(-mx.abs(x))
|
||||||
|
|
||||||
|
def outer(x):
|
||||||
|
    return inner(inner(x))
|
||||||
|
|
||||||
|
# Compiling the outer function is good to do as it will likely
|
||||||
|
# be faster even though the inner functions are compiled
|
||||||
|
fun = mx.compile(outer)
|
191
docs/src/usage/function_transforms.rst
Normal file
191
docs/src/usage/function_transforms.rst
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
.. _function_transforms:
|
||||||
|
|
||||||
|
Function Transforms
|
||||||
|
===================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
MLX uses composable function transformations for automatic differentiation,
|
||||||
|
vectorization, and compute graph optimizations. To see the complete list of
|
||||||
|
function transformations check-out the :ref:`API documentation <transforms>`.
|
||||||
|
|
||||||
|
The key idea behind composable function transformations is that every
|
||||||
|
transformation returns a function which can be further transformed.
|
||||||
|
|
||||||
|
Here is a simple example:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> dfdx = mx.grad(mx.sin)
|
||||||
|
>>> dfdx(mx.array(mx.pi))
|
||||||
|
array(-1, dtype=float32)
|
||||||
|
>>> mx.cos(mx.array(mx.pi))
|
||||||
|
array(-1, dtype=float32)
|
||||||
|
|
||||||
|
|
||||||
|
The output of :func:`grad` on :func:`sin` is simply another function. In this
|
||||||
|
case it is the gradient of the sine function which is exactly the cosine
|
||||||
|
function. To get the second derivative you can do:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
|
||||||
|
>>> d2fdx2(mx.array(mx.pi / 2))
|
||||||
|
array(-1, dtype=float32)
|
||||||
|
>>> mx.sin(mx.array(mx.pi / 2))
|
||||||
|
array(1, dtype=float32)
|
||||||
|
|
||||||
|
Using :func:`grad` on the output of :func:`grad` is always ok. You keep
|
||||||
|
getting higher order derivatives.
|
||||||
|
|
||||||
|
Any of the MLX function transformations can be composed in any order to any
|
||||||
|
depth. See the following sections for more information on :ref:`automatic
|
||||||
|
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
|
||||||
|
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
|
||||||
|
|
||||||
|
|
||||||
|
Automatic Differentiation
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
.. _auto diff:
|
||||||
|
|
||||||
|
Automatic differentiation in MLX works on functions rather than on implicit
|
||||||
|
graphs.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
If you are coming to MLX from PyTorch, you no longer need functions like
|
||||||
|
``backward``, ``zero_grad``, and ``detach``, or properties like
|
||||||
|
``requires_grad``.
|
||||||
|
|
||||||
|
The most basic example is taking the gradient of a scalar-valued function as we
|
||||||
|
saw above. You can use the :func:`grad` and :func:`value_and_grad` function to
|
||||||
|
compute gradients of more complex functions. By default these functions compute
|
||||||
|
the gradient with respect to the first argument:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def loss_fn(w, x, y):
|
||||||
|
return mx.mean(mx.square(w * x - y))
|
||||||
|
|
||||||
|
w = mx.array(1.0)
|
||||||
|
x = mx.array([0.5, -0.5])
|
||||||
|
y = mx.array([1.5, -1.5])
|
||||||
|
|
||||||
|
# Computes the gradient of loss_fn with respect to w:
|
||||||
|
grad_fn = mx.grad(loss_fn)
|
||||||
|
dloss_dw = grad_fn(w, x, y)
|
||||||
|
# Prints array(-1, dtype=float32)
|
||||||
|
print(dloss_dw)
|
||||||
|
|
||||||
|
# To get the gradient with respect to x we can do:
|
||||||
|
grad_fn = mx.grad(loss_fn, argnums=1)
|
||||||
|
dloss_dx = grad_fn(w, x, y)
|
||||||
|
# Prints array([-1, 1], dtype=float32)
|
||||||
|
print(dloss_dx)
|
||||||
|
|
||||||
|
|
||||||
|
One way to get the loss and gradient is to call ``loss_fn`` followed by
|
||||||
|
``grad_fn``, but this can result in a lot of redundant work. Instead, you
|
||||||
|
should use :func:`value_and_grad`. Continuing the above example:
|
||||||
|
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Computes the gradient of loss_fn with respect to w:
|
||||||
|
loss_and_grad_fn = mx.value_and_grad(loss_fn)
|
||||||
|
loss, dloss_dw = loss_and_grad_fn(w, x, y)
|
||||||
|
|
||||||
|
# Prints array(1, dtype=float32)
|
||||||
|
print(loss)
|
||||||
|
|
||||||
|
# Prints array(-1, dtype=float32)
|
||||||
|
print(dloss_dw)
|
||||||
|
|
||||||
|
|
||||||
|
You can also take the gradient with respect to arbitrarily nested Python
|
||||||
|
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
|
||||||
|
:obj:`dict`).
|
||||||
|
|
||||||
|
Suppose we wanted a weight and a bias parameter in the above example. A nice
|
||||||
|
way to do that is the following:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def loss_fn(params, x, y):
|
||||||
|
w, b = params["weight"], params["bias"]
|
||||||
|
h = w * x + b
|
||||||
|
return mx.mean(mx.square(h - y))
|
||||||
|
|
||||||
|
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
|
||||||
|
x = mx.array([0.5, -0.5])
|
||||||
|
y = mx.array([1.5, -1.5])
|
||||||
|
|
||||||
|
# Computes the gradient of loss_fn with respect to both the
|
||||||
|
# weight and bias:
|
||||||
|
grad_fn = mx.grad(loss_fn)
|
||||||
|
grads = grad_fn(params, x, y)
|
||||||
|
|
||||||
|
# Prints
|
||||||
|
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
|
||||||
|
print(grads)
|
||||||
|
|
||||||
|
Notice the tree structure of the parameters is preserved in the gradients.
|
||||||
|
|
||||||
|
In some cases you may want to stop gradients from propagating through a
|
||||||
|
part of the function. You can use the :func:`stop_gradient` for that.
|
||||||
|
|
||||||
|
|
||||||
|
Automatic Vectorization
|
||||||
|
-----------------------
|
||||||
|
|
||||||
|
.. _vmap:
|
||||||
|
|
||||||
|
Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
|
||||||
|
through a basic and contrived example for the sake of clarity, but :func:`vmap`
|
||||||
|
can be quite powerful for more complex functions which are difficult to optimize
|
||||||
|
by hand.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Some operations are not yet supported with :func:`vmap`. If you encounter an error
|
||||||
|
like: ``ValueError: Primitive's vmap not implemented.`` file an `issue
|
||||||
|
<https://github.com/ml-explore/mlx/issues>`_ and include your function.
|
||||||
|
We will prioritize including it.
|
||||||
|
|
||||||
|
A naive way to add the elements from two sets of vectors is with a loop:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
xs = mx.random.uniform(shape=(4096, 100))
|
||||||
|
ys = mx.random.uniform(shape=(100, 4096))
|
||||||
|
|
||||||
|
def naive_add(xs, ys):
|
||||||
|
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
|
||||||
|
|
||||||
|
Instead you can use :func:`vmap` to automatically vectorize the addition:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Vectorize over the second dimension of x and the
|
||||||
|
# first dimension of y
|
||||||
|
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
|
||||||
|
|
||||||
|
The ``in_axes`` parameter can be used to specify which dimensions of the
|
||||||
|
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
|
||||||
|
where the vectorized axes should be in the outputs.
|
||||||
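
A short sketch of ``out_axes``, continuing the ``xs``/``ys`` example above:
the mapped axis (size 100 here) can be placed at any position in the output.

.. code-block:: python

    # With out_axes=1 the mapped axis of size 100 is placed second in the
    # output instead of first (the default out_axes=0)
    vmap_add_t = mx.vmap(lambda x, y: x + y, in_axes=(1, 0), out_axes=1)
    print(vmap_add_t(xs, ys).shape)  # [4096, 100]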
|
|
||||||
|
Let's time these two different versions:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import timeit
|
||||||
|
|
||||||
|
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
|
||||||
|
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
|
||||||
|
|
||||||
|
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
|
||||||
|
vectorized version takes only ``0.025`` seconds, more than ten times faster.
|
||||||
|
|
||||||
|
Of course, this operation is quite contrived. A better approach is to simply do
|
||||||
|
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
|
123
docs/src/usage/indexing.rst
Normal file
123
docs/src/usage/indexing.rst
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
.. _indexing:
|
||||||
|
|
||||||
|
Indexing Arrays
|
||||||
|
===============
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
For the most part, indexing an MLX :obj:`array` works the same as indexing a
|
||||||
|
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
|
||||||
|
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
|
||||||
|
how that works.
|
||||||
|
|
||||||
|
For example, you can use regular integers and slices (:obj:`slice`) to index arrays:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> arr = mx.arange(10)
|
||||||
|
>>> arr[3]
|
||||||
|
array(3, dtype=int32)
|
||||||
|
>>> arr[-2] # negative indexing works
|
||||||
|
array(8, dtype=int32)
|
||||||
|
>>> arr[2:8:2] # start, stop, stride
|
||||||
|
array([2, 4, 6], dtype=int32)
|
||||||
|
|
||||||
|
For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> arr = mx.arange(8).reshape(2, 2, 2)
|
||||||
|
>>> arr[:, :, 0]
|
||||||
|
array([[0, 2],
       [4, 6]], dtype=int32)
|
||||||
|
>>> arr[..., 0]
|
||||||
|
array([[0, 2],
|
||||||
|
       [4, 6]], dtype=int32)
|
||||||
|
|
||||||
|
You can index with ``None`` to create a new axis:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> arr = mx.arange(8)
|
||||||
|
>>> arr.shape
|
||||||
|
[8]
|
||||||
|
>>> arr[None].shape
|
||||||
|
[1, 8]
|
||||||
|
|
||||||
|
|
||||||
|
You can also use an :obj:`array` to index another :obj:`array`:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> arr = mx.arange(10)
|
||||||
|
>>> idx = mx.array([5, 7])
|
||||||
|
>>> arr[idx]
|
||||||
|
array([5, 7], dtype=int32)
|
||||||
|
|
||||||
|
Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
|
||||||
|
works just as in NumPy.
|
||||||
|
|
||||||
|
Other functions which may be useful for indexing arrays are :func:`take` and
|
||||||
|
:func:`take_along_axis`.
|
||||||
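
A brief sketch of both functions:

.. code-block:: python

    import mlx.core as mx

    a = mx.arange(12).reshape(3, 4)

    # Gather whole rows with take
    rows = mx.take(a, mx.array([0, 2]), axis=0)

    # Gather one entry per row with take_along_axis
    idx = mx.array([[0], [3], [1]])
    cols = mx.take_along_axis(a, idx, axis=1)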
|
|
||||||
|
Differences from NumPy
|
||||||
|
----------------------
|
||||||
|
|
||||||
|
.. Note::
|
||||||
|
|
||||||
|
MLX indexing is different from NumPy indexing in two important ways:
|
||||||
|
|
||||||
|
* Indexing does not perform bounds checking. Indexing out of bounds is
|
||||||
|
undefined behavior.
|
||||||
|
* Boolean mask based indexing is not yet supported.
|
||||||
|
|
||||||
|
The reason for the lack of bounds checking is that exceptions cannot propagate
|
||||||
|
from the GPU. Performing bounds checking for array indices before launching the
|
||||||
|
kernel would be extremely inefficient.
|
||||||
|
|
||||||
|
Indexing with boolean masks is something that MLX may support in the future. In
|
||||||
|
general, MLX has limited support for operations for which outputs
|
||||||
|
*shapes* are dependent on input *data*. Other examples of these types of
|
||||||
|
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
||||||
|
single input version of :func:`numpy.where`.
|
||||||
|
|
||||||
|
In Place Updates
|
||||||
|
----------------
|
||||||
|
|
||||||
|
In place updates to indexed arrays are possible in MLX. For example:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1, 2, 3])
|
||||||
|
>>> a[2] = 0
|
||||||
|
>>> a
|
||||||
|
array([1, 2, 0], dtype=int32)
|
||||||
|
|
||||||
|
Just as in NumPy, in place updates will be reflected in all references to the
|
||||||
|
same array:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1, 2, 3])
|
||||||
|
>>> b = a
|
||||||
|
>>> b[2] = 0
|
||||||
|
>>> b
|
||||||
|
array([1, 2, 0], dtype=int32)
|
||||||
|
>>> a
|
||||||
|
array([1, 2, 0], dtype=int32)
|
||||||
|
|
||||||
|
Transformations of functions which use in-place updates are allowed and work as
|
||||||
|
expected. For example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x, idx):
|
||||||
|
x[idx] = 2.0
|
||||||
|
return x.sum()
|
||||||
|
|
||||||
|
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
|
||||||
|
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
|
||||||
|
|
||||||
|
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
|
||||||
|
and ones elsewhere.
|
144
docs/src/usage/lazy_evaluation.rst
Normal file
144
docs/src/usage/lazy_evaluation.rst
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
.. _lazy eval:
|
||||||
|
|
||||||
|
Lazy Evaluation
|
||||||
|
===============
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
Why Lazy Evaluation
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
When you perform operations in MLX, no computation actually happens. Instead a
|
||||||
|
compute graph is recorded. The actual computation only happens if an
|
||||||
|
:func:`eval` is performed.
|
||||||
|
|
||||||
|
MLX uses lazy evaluation because it has some nice features, some of which we
|
||||||
|
describe below.
|
||||||
|
|
||||||
|
Transforming Compute Graphs
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Lazy evaluation lets us record a compute graph without actually doing any
|
||||||
|
computations. This is useful for function transformations like :func:`grad` and
|
||||||
|
:func:`vmap` and graph optimizations.
|
||||||
|
|
||||||
|
Currently, MLX does not compile and rerun compute graphs. They are all
|
||||||
|
generated dynamically. However, lazy evaluation makes it much easier to
|
||||||
|
integrate compilation for future performance enhancements.
|
||||||
|
|
||||||
|
Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^

In MLX you do not need to worry as much about computing outputs that are never
used. For example:

.. code-block:: python

   def fun(x):
       a = fun1(x)
       b = expensive_fun(a)
       return a, b

   y, _ = fun(x)

Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built, and
that has some cost associated with it.

Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your maximum
consumed memory will be half that required if eager computation were used
instead.

This pattern is simple to do in MLX thanks to lazy computation:

.. code-block:: python

   model = Model()  # no memory used yet
   model.load_weights("weights_fp16.safetensors")

When to Evaluate
----------------

A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.

For example:

.. code-block:: python

   for _ in range(100):
       a = a + b
       mx.eval(a)
       b = b * 2
       mx.eval(b)

This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.

Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration in
stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.

Here is a concrete example:

.. code-block:: python

   for batch in dataset:

       # Nothing has been evaluated yet
       loss, grad = value_and_grad_fn(model, batch)

       # Still nothing has been evaluated
       optimizer.update(model, grad)

       # Evaluate the loss and the new parameters which will
       # run the full gradient computation and optimizer update
       mx.eval(loss, model.parameters())

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to a
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
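For instance, each of the following lines forces an evaluation of the graph
behind ``c`` (a small illustrative sketch; ``c`` stands in for any lazily
computed array):

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.array([1.0, 2.0, 3.0])
   b = mx.array([4.0, 5.0, 6.0])
   c = a + b        # only the graph is recorded so far

   print(c)         # printing evaluates c
   np.array(c)      # converting to NumPy evaluates c
   mx.save("c", c)  # saving evaluates c as well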
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines come before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
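To avoid that partial evaluation, one option is to record the loss only after
the full evaluation has been requested (a sketch of the same loop; the
``losses`` list is hypothetical):

.. code-block:: python

   losses = []
   for batch in dataset:
       loss, grad = value_and_grad_fn(model, batch)
       optimizer.update(model, grad)

       # Evaluate the loss and the parameters together first ...
       mx.eval(loss, model.parameters())

       # ... then it is cheap to pull out the scalar value
       losses.append(loss.item())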
Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.

.. warning::

   Using scalar arrays for control flow will cause an evaluation.

Here is an example:

.. code-block:: python

   def fun(x):
       h, y = first_layer(x)
       if y > 0:  # An evaluation is done here!
           z = second_layer_a(h)
       else:
           z = second_layer_b(h)
       return z

Using arrays for control flow should be done with care. The above example works
and can even be used with gradient transformations. However, this can be very
inefficient if evaluations are done too frequently.
108
docs/src/usage/numpy.rst
Normal file
@@ -0,0 +1,108 @@
.. _numpy:

Conversion to NumPy and Other Frameworks
========================================

MLX arrays implement the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3)
   b = np.array(a)  # copy of a
   c = mx.array(b)  # copy of b

.. note::

   Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
   ``np.array(a.astype(mx.float32))``.
   Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by creating an array view:

.. code-block:: python

   a = mx.arange(3)
   a_view = np.array(a, copy=False)
   print(a_view.flags.owndata)  # False
   a_view[0] = 1
   print(a[0].item())  # 1

A NumPy array view is a normal NumPy array, except that it does not own its memory.
This means writing to the view is reflected in the original array.

While this is quite powerful for avoiding copies of arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.

Let's demonstrate this in an example:

.. code-block:: python

   def f(x):
       x_view = np.array(x, copy=False)
       x_view[:] *= x_view  # modify memory without telling mx
       return x.sum()

   x = mx.array([3.0])
   y, df = mx.value_and_grad(f)(x)
   print("f(x) = x² =", y.item())   # 9.0
   print("f'(x) = 2x !=", df.item())  # 1.0

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
representing the gradient of the sum operation alone.
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
It's important to note that a similar issue arises during array conversion and copying.
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
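One way to get the expected gradient is to keep the squaring inside MLX so it
is part of the recorded graph (a small sketch reusing the setup above):

.. code-block:: python

   def f(x):
       return (x * x).sum()  # the squaring is now visible to MLX

   x = mx.array([3.0])
   y, df = mx.value_and_grad(f)(x)
   print(y.item())   # 9.0
   print(df.item())  # 6.0, i.e. 2 * x as expected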
PyTorch
-------

.. warning::

   PyTorch support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import torch

   a = mx.arange(3)
   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())

Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.

JAX
---

JAX fully supports the buffer protocol.

.. code-block:: python

   import mlx.core as mx
   import jax.numpy as jnp

   a = mx.arange(3)
   b = jnp.array(a)
   c = mx.array(b)

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import tensorflow as tf

   a = mx.arange(3)
   b = tf.constant(memoryview(a))
   c = mx.array(b)
@@ -40,6 +40,9 @@ automatically evaluate the array.
>> np.array(c) # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)

See the page on :ref:`Lazy Evaluation <lazy eval>` for more details.

Function and Graph Transformations
----------------------------------

@@ -62,10 +65,3 @@ and :func:`jvp` for Jacobian-vector products.

Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.

Devices and Streams
-------------------
81
docs/src/usage/saving_and_loading.rst
Normal file
@@ -0,0 +1,81 @@
.. _saving_and_loading:

Saving and Loading Arrays
=========================

.. currentmodule:: mlx.core

MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the file extension. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> mx.save("array", a)

The array ``a`` will be saved in the file ``array.npy`` (notice the extension
is automatically added). Including the extension is optional; if it is missing
it will be added. You can load the array with:

.. code-block:: shell

   >>> mx.load("array.npy")
   array([1], dtype=float32)

Here's an example of saving several arrays to a single file:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> b = mx.array([2.0])
   >>> mx.savez("arrays", a, b=b)

For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays
as positional or keyword arguments. If the keywords are missing, then default
names will be provided. This can be loaded with:

.. code-block:: shell

   >>> mx.load("arrays.npz")
   {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}

In this case :func:`load` returns a dictionary of names to arrays.

The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:

.. code-block:: shell

   >>> a = mx.array([1.0])
   >>> b = mx.array([2.0])
   >>> mx.save_safetensors("arrays", {"a": a, "b": b})
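Loading works the same way as for ``.npz`` files: pass the file name to
:func:`load` and you get back a dictionary of names to arrays (a sketch; the
exact output formatting may differ):

.. code-block:: shell

   >>> mx.load("arrays.safetensors")
   {'a': array([1], dtype=float32), 'b': array([2], dtype=float32)}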
78
docs/src/usage/unified_memory.rst
Normal file
@@ -0,0 +1,78 @@
.. _unified_memory:

Unified Memory
==============

.. currentmodule:: mlx.core

Apple silicon has a unified memory architecture. The CPU and GPU have direct
access to the same memory pool. MLX is designed to take advantage of that.

Concretely, when you make an array in MLX you don't have to specify its location:

.. code-block:: python

   a = mx.random.normal((100,))
   b = mx.random.normal((100,))

Both ``a`` and ``b`` live in unified memory.

In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:

.. code-block:: python

   mx.add(a, b, stream=mx.cpu)
   mx.add(a, b, stream=mx.gpu)

In the above, both the CPU and the GPU will perform the same add
operation. The operations can (and likely will) be run in parallel since
there are no dependencies between them. See :ref:`using_streams` for more
information on the semantics of streams in MLX.

In the above ``add`` example, there are no dependencies between operations, so
there is no possibility for race conditions. If there are dependencies, the
MLX scheduler will automatically manage them. For example:

.. code-block:: python

   c = mx.add(a, b, stream=mx.cpu)
   d = mx.add(a, c, stream=mx.gpu)

In the above case, the second ``add`` runs on the GPU but it depends on the
output of the first ``add`` which is running on the CPU. MLX will
automatically insert a dependency between the two streams so that the second
``add`` only starts executing after the first is complete and ``c`` is
available.

A Simple Example
~~~~~~~~~~~~~~~~

Here is a more interesting (albeit slightly contrived) example of how unified
memory can be helpful. Suppose we have the following computation:

.. code-block:: python

   def fun(a, b, d1, d2):
       x = mx.matmul(a, b, stream=d1)
       for _ in range(500):
           b = mx.exp(b, stream=d2)
       return x, b

which we want to run with the following arguments:

.. code-block:: python

   a = mx.random.uniform(shape=(4096, 512))
   b = mx.random.uniform(shape=(512, 4))

The first ``matmul`` operation is a good fit for the GPU since it is more
compute dense. The second sequence of operations is a better fit for the CPU,
since the operations are very small and would probably be overhead bound on
the GPU.

If we time the computation fully on the GPU, we get 2.8 milliseconds. But if we
run the computation with ``d1=mx.gpu`` and ``d2=mx.cpu``, then the time is only
about 1.4 milliseconds, about twice as fast. These times were measured on an M1
Max.
@@ -1,3 +1,5 @@
.. _using_streams:

Using Streams
=============
@@ -57,7 +57,7 @@ void array_basics() {
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);

// To actually run the compuation you must evaluate `z`.
// To actually run the computation you must evaluate `z`.
// Under the hood, mlx records operations in a graph.
// The variable `z` is a node in the graph which points to its operation
// and inputs. When `eval` is called on an array (or arrays), the array and
@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.24)
cmake_minimum_required(VERSION 3.27)

project(mlx_sample_extensions LANGUAGES CXX)
@@ -63,4 +63,4 @@ target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
|||||||
|
|
||||||
if(BUILD_SHARED_LIBS)
|
if(BUILD_SHARED_LIBS)
|
||||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||||
endif()
|
endif()
|
||||||
|
@@ -26,7 +26,7 @@ namespace mlx::core {
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Scale and sum two vectors elementwise
|
* Scale and sum two vectors element-wise
|
||||||
* z = alpha * x + beta * y
|
* z = alpha * x + beta * y
|
||||||
*
|
*
|
||||||
* Follow numpy style broadcasting between x and y
|
* Follow numpy style broadcasting between x and y
|
||||||
@@ -91,21 +91,24 @@ void axpby_impl(
|
|||||||
T alpha = static_cast<T>(alpha_);
|
T alpha = static_cast<T>(alpha_);
|
||||||
T beta = static_cast<T>(beta_);
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
// Do the elementwise operation for each output
|
// Do the element-wise operation for each output
|
||||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||||
// Map linear indices to offsets in x and y
|
// Map linear indices to offsets in x and y
|
||||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||||
|
|
||||||
// We allocate the output to be contiguous and regularly strided
|
// We allocate the output to be contiguous and regularly strided
|
||||||
// (defaults to row major) and hence it doesn't need additonal mapping
|
// (defaults to row major) and hence it doesn't need additional mapping
|
||||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
/** Fall back implementation for evaluation on CPU */
|
||||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval(
|
||||||
// Check the inputs (registered in the op while contructing the out array)
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& out_arr) {
|
||||||
|
auto out = out_arr[0];
|
||||||
|
// Check the inputs (registered in the op while constructing the out array)
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
@@ -175,7 +178,10 @@ void axpby_impl_accelerate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Evaluate primitive on CPU using accelerate specializations */
|
/** Evaluate primitive on CPU using accelerate specializations */
|
||||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outarr) {
|
||||||
|
auto out = outarr[0];
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
@@ -189,13 +195,15 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to common backend if specializations are not available
|
// Fall back to common backend if specializations are not available
|
||||||
eval(inputs, out);
|
eval(inputs, outarr);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // Accelerate not avaliable
|
#else // Accelerate not available
|
||||||
|
|
||||||
/** Evaluate primitive on CPU falling back to common backend */
|
/** Evaluate primitive on CPU falling back to common backend */
|
||||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& out) {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,8 +216,11 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
#ifdef _METAL_
|
#ifdef _METAL_
|
||||||
|
|
||||||
/** Evaluate primitive on GPU */
|
/** Evaluate primitive on GPU */
|
||||||
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outarr) {
|
||||||
// Prepare inputs
|
// Prepare inputs
|
||||||
|
auto out = outarr[0];
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
@@ -254,7 +265,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Kernel parameters are registered with buffer indices corresponding to
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
// those in the kernel decelaration at axpby.metal
|
// those in the kernel declaration at axpby.metal
|
||||||
int ndim = out.ndim();
|
int ndim = out.ndim();
|
||||||
size_t nelem = out.size();
|
size_t nelem = out.size();
|
||||||
|
|
||||||
@@ -287,7 +298,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Fix the 3D size of the launch grid (in terms of threads)
|
// Fix the 3D size of the launch grid (in terms of threads)
|
||||||
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
||||||
|
|
||||||
// Launch the grid with the given number of threads divded among
|
// Launch the grid with the given number of threads divided among
|
||||||
// the given threadgroups
|
// the given threadgroups
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
@@ -295,7 +306,9 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
#else // Metal is not available
|
#else // Metal is not available
|
||||||
|
|
||||||
/** Fail evaluation on GPU */
|
/** Fail evaluation on GPU */
|
||||||
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& out) {
|
||||||
throw std::runtime_error("Axpby has no GPU implementation.");
|
throw std::runtime_error("Axpby has no GPU implementation.");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,13 +319,13 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/** The Jacobian-vector product. */
|
/** The Jacobian-vector product. */
|
||||||
array Axpby::jvp(
|
std::vector<array> Axpby::jvp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
// Forward mode diff that pushes along the tangents
|
// Forward mode diff that pushes along the tangents
|
||||||
// The jvp transform on the the primitive can built with ops
|
// The jvp transform on the primitive can built with ops
|
||||||
// that are scheduled on the same stream as the primtive
|
// that are scheduled on the same stream as the primitive
|
||||||
|
|
||||||
// If argnums = {0}, we only push along x in which case the
|
// If argnums = {0}, we only push along x in which case the
|
||||||
// jvp is just the tangent scaled by alpha
|
// jvp is just the tangent scaled by alpha
|
||||||
@@ -321,32 +334,33 @@ array Axpby::jvp(
|
|||||||
if (argnums.size() > 1) {
|
if (argnums.size() > 1) {
|
||||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||||
auto scale_arr = array(scale, tangents[0].dtype());
|
auto scale_arr = array(scale, tangents[0].dtype());
|
||||||
return multiply(scale_arr, tangents[0], stream());
|
return {multiply(scale_arr, tangents[0], stream())};
|
||||||
}
|
}
|
||||||
// If, argnums = {0, 1}, we take contributions from both
|
// If, argnums = {0, 1}, we take contributions from both
|
||||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||||
else {
|
else {
|
||||||
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
|
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** The vector-Jacobian product. */
|
/** The vector-Jacobian product. */
|
||||||
std::vector<array> Axpby::vjp(
|
std::vector<array> Axpby::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const array& cotan,
|
const std::vector<array>& cotangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums,
|
||||||
|
const std::vector<array>&) {
|
||||||
// Reverse mode diff
|
// Reverse mode diff
|
||||||
std::vector<array> vjps;
|
std::vector<array> vjps;
|
||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
auto scale = arg == 0 ? alpha_ : beta_;
|
auto scale = arg == 0 ? alpha_ : beta_;
|
||||||
auto scale_arr = array(scale, cotan.dtype());
|
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||||
vjps.push_back(multiply(scale_arr, cotan, stream()));
|
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||||
}
|
}
|
||||||
return vjps;
|
return vjps;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Vectorize primitve along given axis */
|
/** Vectorize primitive along given axis */
|
||||||
std::pair<array, int> Axpby::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||||
|
@@ -12,7 +12,7 @@ namespace mlx::core {
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Scale and sum two vectors elementwise
|
* Scale and sum two vectors element-wise
|
||||||
* z = alpha * x + beta * y
|
* z = alpha * x + beta * y
|
||||||
*
|
*
|
||||||
* Follow numpy style broadcasting between x and y
|
* Follow numpy style broadcasting between x and y
|
||||||
@@ -39,14 +39,16 @@ class Axpby : public Primitive {
|
|||||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||||
* for the given inputs and populate the output array.
|
* for the given inputs and populate the output array.
|
||||||
*
|
*
|
||||||
* To avoid unecessary allocations, the evaluation function
|
* To avoid unnecessary allocations, the evaluation function
|
||||||
* is responsible for allocating space for the array.
|
* is responsible for allocating space for the array.
|
||||||
*/
|
*/
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
|
||||||
|
override;
|
||||||
|
|
||||||
/** The Jacobian-vector product. */
|
/** The Jacobian-vector product. */
|
||||||
array jvp(
|
std::vector<array> jvp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) override;
|
const std::vector<int>& argnums) override;
|
||||||
@@ -54,16 +56,17 @@ class Axpby : public Primitive {
|
|||||||
/** The vector-Jacobian product. */
|
/** The vector-Jacobian product. */
|
||||||
std::vector<array> vjp(
|
std::vector<array> vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const array& cotan,
|
const std::vector<array>& cotangents,
|
||||||
const std::vector<int>& argnums) override;
|
const std::vector<int>& argnums,
|
||||||
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The primitive must know how to vectorize itself accross
|
* The primitive must know how to vectorize itself across
|
||||||
* the given axes. The output is a pair containing the array
|
* the given axes. The output is a pair containing the array
|
||||||
* representing the vectorized computation and the axis which
|
* representing the vectorized computation and the axis which
|
||||||
* corresponds to the output vectorized dimension.
|
* corresponds to the output vectorized dimension.
|
||||||
*/
|
*/
|
||||||
std::pair<array, int> vmap(
|
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
@@ -80,7 +83,7 @@ class Axpby : public Primitive {
|
|||||||
float beta_;
|
float beta_;
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
/** Fall back implementation for evaluation on CPU */
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, std::vector<array>& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
@@ -59,5 +59,5 @@ template <typename T>

instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bflot16, bfloat16_t);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
@@ -23,7 +23,7 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = py::none(),
|
"stream"_a = py::none(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Scale and sum two vectors elementwise
|
Scale and sum two vectors element-wise
|
||||||
``z = alpha * x + beta * y``
|
``z = alpha * x + beta * y``
|
||||||
|
|
||||||
Follows numpy style broadcasting between ``x`` and ``y``
|
Follows numpy style broadcasting between ``x`` and ``y``
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .mlx_sample_extensions import *
|
from .mlx_sample_extensions import *
|
||||||
|
3
examples/extensions/pyproject.toml
Normal file
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
build-backend = "setuptools.build_meta"
@@ -1,8 +1,9 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from mlx import extension
|
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
|
from mlx import extension
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
setup(
|
setup(
|
||||||
name="mlx_sample_extensions",
|
name="mlx_sample_extensions",
|
||||||
@@ -14,5 +15,5 @@ if __name__ == "__main__":
|
|||||||
package_dir={"": "."},
|
package_dir={"": "."},
|
||||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
python_requires=">=3.7",
|
python_requires=">=3.8",
|
||||||
)
|
)
|
||||||
|
@@ -1,8 +1,9 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
num_features = 100
|
num_features = 100
|
||||||
num_examples = 1_000
|
num_examples = 1_000
|
||||||
num_iters = 10_000
|
num_iters = 10_000
|
||||||
@@ -40,6 +41,6 @@ error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
|
|||||||
throughput = num_iters / (toc - tic)
|
throughput = num_iters / (toc - tic)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
|
f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, "
|
||||||
f"Throughput {throughput:.5f} (it/s)"
|
f"Throughput {throughput:.5f} (it/s)"
|
||||||
)
|
)
|
||||||
|
@@ -1,8 +1,9 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
num_features = 100
|
num_features = 100
|
||||||
num_examples = 1_000
|
num_examples = 1_000
|
||||||
num_iters = 10_000
|
num_iters = 10_000
|
||||||
|
@@ -3,23 +3,25 @@ target_sources(
|
|||||||
PRIVATE
|
PRIVATE
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
||||||
)
|
)
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||||
if (MLX_BUILD_ACCELERATE)
|
if (MLX_BUILD_ACCELERATE)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(
|
||||||
|
@@ -9,7 +9,7 @@
|
|||||||
namespace mlx::core::allocator {
|
namespace mlx::core::allocator {
|
||||||
|
|
||||||
Buffer malloc(size_t size) {
|
Buffer malloc(size_t size) {
|
||||||
auto buffer = allocator().malloc(size);
|
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
||||||
if (size && !buffer.ptr()) {
|
if (size && !buffer.ptr()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||||
@@ -22,7 +22,7 @@ void free(Buffer buffer) {
|
|||||||
return allocator().free(buffer);
|
return allocator().free(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer CommonAllocator::malloc(size_t size) {
|
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||||
return Buffer{std::malloc(size)};
|
return Buffer{std::malloc(size)};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,6 +38,11 @@ Buffer malloc_or_wait(size_t size) {
|
|||||||
buffer = allocator().malloc(size);
|
buffer = allocator().malloc(size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try swapping if needed
|
||||||
|
if (size && !buffer.ptr()) {
|
||||||
|
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
||||||
|
}
|
||||||
|
|
||||||
if (size && !buffer.ptr()) {
|
if (size && !buffer.ptr()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
||||||
|
@@ -37,9 +37,9 @@ void free(Buffer buffer);
|
|||||||
Buffer malloc_or_wait(size_t size);
|
Buffer malloc_or_wait(size_t size);
|
||||||
|
|
||||||
class Allocator {
|
class Allocator {
|
||||||
/** Abstract base clase for a memory allocator. */
|
/** Abstract base class for a memory allocator. */
|
||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size) = 0;
|
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||||
virtual void free(Buffer buffer) = 0;
|
virtual void free(Buffer buffer) = 0;
|
||||||
|
|
||||||
Allocator() = default;
|
Allocator() = default;
|
||||||
@@ -55,7 +55,7 @@ Allocator& allocator();
|
|||||||
class CommonAllocator : public Allocator {
|
class CommonAllocator : public Allocator {
|
||||||
/** A general CPU allocator. */
|
/** A general CPU allocator. */
|
||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size) override;
|
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||||
virtual void free(Buffer buffer) override;
|
virtual void free(Buffer buffer) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
106
mlx/array.cpp
106
mlx/array.cpp
@@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
@@ -6,6 +6,7 @@
|
|||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
|
#include "mlx/transforms_impl.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
|||||||
return {cum_prod, strides};
|
return {cum_prod, strides};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Return true if we are currently performing a function transformation in
|
||||||
|
* order to keep the graph when evaluating tracer arrays. */
|
||||||
|
bool in_tracing() {
|
||||||
|
return detail::InTracing::in_tracing();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||||
@@ -32,7 +39,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
|||||||
array::array(
|
array::array(
|
||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::unique_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
const std::vector<array>& inputs)
|
const std::vector<array>& inputs)
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
shape,
|
shape,
|
||||||
@@ -40,6 +47,34 @@ array::array(
|
|||||||
std::move(primitive),
|
std::move(primitive),
|
||||||
inputs)) {}
|
inputs)) {}
|
||||||
|
|
||||||
|
array::array(
|
||||||
|
std::vector<int> shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
std::vector<array>&& inputs)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
|
std::move(shape),
|
||||||
|
dtype,
|
||||||
|
std::move(primitive),
|
||||||
|
std::move(inputs))) {}
|
||||||
|
|
||||||
|
std::vector<array> array::make_arrays(
|
||||||
|
const std::vector<std::vector<int>>& shapes,
|
||||||
|
const std::vector<Dtype>& dtypes,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
std::vector<array> outputs;
|
||||||
|
for (int i = 0; i < shapes.size(); ++i) {
|
||||||
|
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
|
||||||
|
}
|
||||||
|
for (int i = 0; i < outputs.size(); ++i) {
|
||||||
|
auto siblings = outputs;
|
||||||
|
siblings.erase(siblings.begin() + i);
|
||||||
|
outputs[i].set_siblings(std::move(siblings), i);
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
array::array(std::initializer_list<float> data)
|
array::array(std::initializer_list<float> data)
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
std::vector<int>{static_cast<int>(data.size())},
|
std::vector<int>{static_cast<int>(data.size())},
|
||||||
@@ -47,6 +82,13 @@ array::array(std::initializer_list<float> data)
|
|||||||
init(data.begin());
|
init(data.begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array::array(std::initializer_list<int> data, Dtype dtype)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
|
std::vector<int>{static_cast<int>(data.size())},
|
||||||
|
dtype)) {
|
||||||
|
init(data.begin());
|
||||||
|
}
|
||||||
|
|
||||||
/* Build an array from a shared buffer */
|
/* Build an array from a shared buffer */
|
||||||
array::array(
|
array::array(
|
||||||
allocator::Buffer data,
|
allocator::Buffer data,
|
||||||
@@ -58,12 +100,26 @@ array::array(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void array::detach() {
|
void array::detach() {
|
||||||
|
for (auto& s : array_desc_->siblings) {
|
||||||
|
s.array_desc_->inputs.clear();
|
||||||
|
s.array_desc_->siblings.clear();
|
||||||
|
s.array_desc_->position = 0;
|
||||||
|
s.array_desc_->depth = 0;
|
||||||
|
s.array_desc_->primitive = nullptr;
|
||||||
|
}
|
||||||
array_desc_->inputs.clear();
|
array_desc_->inputs.clear();
|
||||||
|
array_desc_->siblings.clear();
|
||||||
|
array_desc_->position = 0;
|
||||||
|
array_desc_->depth = 0;
|
||||||
array_desc_->primitive = nullptr;
|
array_desc_->primitive = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::eval(bool retain_graph /* = false */) {
|
void array::eval() {
|
||||||
mlx::core::eval({*this}, retain_graph);
|
mlx::core::eval({*this});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool array::is_tracer() const {
|
||||||
|
return array_desc_->is_tracer && in_tracing();
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||||
@@ -108,6 +164,14 @@ void array::copy_shared_buffer(const array& other) {
|
|||||||
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void array::move_shared_buffer(array other) {
|
||||||
|
array_desc_->data = std::move(other.array_desc_->data);
|
||||||
|
array_desc_->strides = other.strides();
|
||||||
|
array_desc_->flags = other.flags();
|
||||||
|
array_desc_->data_size = other.data_size();
|
||||||
|
array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
||||||
: shape(shape), dtype(dtype) {
|
: shape(shape), dtype(dtype) {
|
||||||
std::tie(size, strides) = cum_prod(shape);
|
std::tie(size, strides) = cum_prod(shape);
|
||||||
@@ -116,21 +180,43 @@ array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
|||||||
array::ArrayDesc::ArrayDesc(
|
array::ArrayDesc::ArrayDesc(
|
||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::unique_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
const std::vector<array>& inputs)
|
const std::vector<array>& inputs)
|
||||||
: shape(shape),
|
: shape(shape),
|
||||||
dtype(dtype),
|
dtype(dtype),
|
||||||
primitive(std::move(primitive)),
|
primitive(std::move(primitive)),
|
||||||
inputs(inputs) {
|
inputs(inputs) {
|
||||||
std::tie(size, strides) = cum_prod(shape);
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
for (auto& in : inputs) {
|
for (auto& in : this->inputs) {
|
||||||
is_tracer |= in.is_tracer();
|
is_tracer |= in.is_tracer();
|
||||||
|
depth = std::max(in.graph_depth(), depth);
|
||||||
}
|
}
|
||||||
|
depth++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Needed because the Primitive type used in array.h is incomplete and the
|
array::ArrayDesc::ArrayDesc(
|
||||||
// compiler needs to see the call to the desctructor after the type is complete.
|
std::vector<int>&& shape,
|
||||||
array::ArrayDesc::~ArrayDesc() = default;
|
Dtype dtype,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
std::vector<array>&& inputs)
|
||||||
|
: shape(std::move(shape)),
|
||||||
|
dtype(dtype),
|
||||||
|
primitive(std::move(primitive)),
|
||||||
|
inputs(std::move(inputs)) {
|
||||||
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
|
for (auto& in : this->inputs) {
|
||||||
|
is_tracer |= in.is_tracer();
|
||||||
|
depth = std::max(in.graph_depth(), depth);
|
||||||
|
}
|
||||||
|
depth++;
|
||||||
|
}
|
||||||
|
|
||||||
|
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||||
|
: arr(arr), idx(idx) {
|
||||||
|
if (arr.ndim() == 0) {
|
||||||
|
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||||
auto start = std::vector<int>(arr.ndim(), 0);
|
auto start = std::vector<int>(arr.ndim(), 0);
|
||||||
|
127
mlx/array.h
127
mlx/array.h
@@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
@@ -42,6 +41,9 @@ class array {
|
|||||||
/* Special case so empty lists default to float32. */
|
/* Special case so empty lists default to float32. */
|
||||||
array(std::initializer_list<float> data);
|
array(std::initializer_list<float> data);
|
||||||
|
|
||||||
|
/* Special case so array({}, type) is an empty array. */
|
||||||
|
array(std::initializer_list<int> data, Dtype dtype);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array(
|
array(
|
||||||
std::initializer_list<T> data,
|
std::initializer_list<T> data,
|
||||||
@@ -116,11 +118,14 @@ class array {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/** Evaluate the array. */
|
/** Evaluate the array. */
|
||||||
void eval(bool retain_graph = false);
|
void eval();
|
||||||
|
|
||||||
/** Get the value from a scalar array. */
|
/** Get the value from a scalar array. */
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T item(bool retain_graph = false);
|
T item();
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T item() const;
|
||||||
|
|
||||||
struct ArrayIterator {
|
struct ArrayIterator {
|
||||||
using iterator_category = std::random_access_iterator_tag;
|
using iterator_category = std::random_access_iterator_tag;
|
||||||
@@ -128,11 +133,7 @@ class array {
|
|||||||
using value_type = const array;
|
using value_type = const array;
|
||||||
using reference = value_type;
|
using reference = value_type;
|
||||||
|
|
||||||
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
|
explicit ArrayIterator(const array& arr, int idx = 0);
|
||||||
if (arr.ndim() == 0) {
|
|
||||||
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
reference operator*() const;
|
reference operator*() const;
|
||||||
|
|
||||||
@@ -154,8 +155,8 @@ class array {
|
|||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int idx;
|
|
||||||
const array& arr;
|
const array& arr;
|
||||||
|
int idx;
|
||||||
};
|
};
|
||||||
|
|
||||||
ArrayIterator begin() const {
|
ArrayIterator begin() const {
|
||||||
@@ -174,7 +175,19 @@ class array {
|
|||||||
array(
|
array(
|
||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::unique_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
array(
|
||||||
|
std::vector<int> shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
std::vector<array>&& inputs);
|
||||||
|
|
||||||
|
static std::vector<array> make_arrays(
|
||||||
|
const std::vector<std::vector<int>>& shapes,
|
||||||
|
const std::vector<Dtype>& dtypes,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
const std::vector<array>& inputs);
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
/** A unique identifier for an array. */
|
/** A unique identifier for an array. */
|
||||||
@@ -182,6 +195,11 @@ class array {
|
|||||||
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** A unique identifier for an arrays primitive. */
|
||||||
|
std::uintptr_t primitive_id() const {
|
||||||
|
return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
|
||||||
|
}
|
||||||
|
|
||||||
struct Data {
|
struct Data {
|
||||||
allocator::Buffer buffer;
|
allocator::Buffer buffer;
|
||||||
deleter_t d;
|
deleter_t d;
|
||||||
@@ -209,6 +227,11 @@ class array {
|
|||||||
return *(array_desc_->primitive);
|
return *(array_desc_->primitive);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** A shared pointer to the array's primitive. */
|
||||||
|
std::shared_ptr<Primitive>& primitive_ptr() const {
|
||||||
|
return array_desc_->primitive;
|
||||||
|
};
|
||||||
|
|
||||||
/** Check if the array has an attached primitive or is a leaf node. */
|
/** Check if the array has an attached primitive or is a leaf node. */
|
||||||
bool has_primitive() const {
|
bool has_primitive() const {
|
||||||
return array_desc_->primitive != nullptr;
|
return array_desc_->primitive != nullptr;
|
||||||
@@ -219,12 +242,42 @@ class array {
|
|||||||
return array_desc_->inputs;
|
return array_desc_->inputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** A non-const reference to the array's inputs so that they can be used to
|
std::vector<array>& inputs() {
|
||||||
* edit the graph. */
|
|
||||||
std::vector<array>& editable_inputs() {
|
|
||||||
return array_desc_->inputs;
|
return array_desc_->inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** True indicates the arrays buffer is safe to reuse */
|
||||||
|
bool is_donatable() const {
|
||||||
|
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The array's siblings. */
|
||||||
|
const std::vector<array>& siblings() const {
|
||||||
|
return array_desc_->siblings;
|
||||||
|
};
|
||||||
|
|
||||||
|
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||||
|
array_desc_->siblings = std::move(siblings);
|
||||||
|
array_desc_->position = position;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The outputs of the array's primitive (i.e. this array and
|
||||||
|
* its siblings) in the order the primitive expects. */
|
||||||
|
std::vector<array> outputs() const {
|
||||||
|
auto idx = array_desc_->position;
|
||||||
|
std::vector<array> outputs;
|
||||||
|
outputs.reserve(siblings().size() + 1);
|
||||||
|
outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
|
||||||
|
outputs.push_back(*this);
|
||||||
|
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
|
||||||
|
return outputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
|
||||||
|
uint16_t graph_depth() const {
|
||||||
|
return array_desc_->depth;
|
||||||
|
}
|
||||||
|
|
||||||
/** Detach the array from the graph. */
|
/** Detach the array from the graph. */
|
||||||
void detach();
|
void detach();
|
||||||
|
|
||||||
@@ -245,6 +298,12 @@ class array {
|
|||||||
return array_desc_->data->buffer;
|
return array_desc_->data->buffer;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Return a copy of the shared pointer
|
||||||
|
// to the array::Data struct
|
||||||
|
std::shared_ptr<Data> data_shared_ptr() const {
|
||||||
|
return array_desc_->data;
|
||||||
|
}
|
||||||
|
// Return a raw pointer to the arrays data
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* data() {
|
T* data() {
|
||||||
return static_cast<T*>(array_desc_->data_ptr);
|
return static_cast<T*>(array_desc_->data_ptr);
|
||||||
@@ -265,9 +324,7 @@ class array {
|
|||||||
array_desc_->is_tracer = is_tracer;
|
array_desc_->is_tracer = is_tracer;
|
||||||
}
|
}
|
||||||
// Check if the array is a tracer array
|
// Check if the array is a tracer array
|
||||||
bool is_tracer() const {
|
bool is_tracer() const;
|
||||||
return array_desc_->is_tracer;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||||
|
|
||||||
@@ -287,6 +344,8 @@ class array {
|
|||||||
|
|
||||||
void copy_shared_buffer(const array& other);
|
void copy_shared_buffer(const array& other);
|
||||||
|
|
||||||
|
void move_shared_buffer(array other);
|
||||||
|
|
||||||
void overwrite_descriptor(const array& other) {
|
void overwrite_descriptor(const array& other) {
|
||||||
array_desc_ = other.array_desc_;
|
array_desc_ = other.array_desc_;
|
||||||
}
|
}
|
||||||
@@ -301,7 +360,7 @@ class array {
|
|||||||
std::vector<size_t> strides;
|
std::vector<size_t> strides;
|
||||||
size_t size;
|
size_t size;
|
||||||
Dtype dtype;
|
Dtype dtype;
|
||||||
std::unique_ptr<Primitive> primitive{nullptr};
|
std::shared_ptr<Primitive> primitive{nullptr};
|
||||||
|
|
||||||
// Indicates an array is being used in a graph transform
|
// Indicates an array is being used in a graph transform
|
||||||
// and should not be detached from the graph
|
// and should not be detached from the graph
|
||||||
@@ -323,22 +382,34 @@ class array {
|
|||||||
Flags flags;
|
Flags flags;
|
||||||
|
|
||||||
std::vector<array> inputs;
|
std::vector<array> inputs;
|
||||||
|
// An array to keep track of the siblings from a multi-output
|
||||||
|
// primitive.
|
||||||
|
std::vector<array> siblings;
|
||||||
|
// The arrays position in the output list
|
||||||
|
uint32_t position{0};
|
||||||
|
|
||||||
|
// The depth of the array in the graph.
|
||||||
|
uint16_t depth{0};
|
||||||
|
|
||||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||||
|
|
||||||
explicit ArrayDesc(
|
explicit ArrayDesc(
|
||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::unique_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
const std::vector<array>& inputs);
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
~ArrayDesc();
|
explicit ArrayDesc(
|
||||||
|
std::vector<int>&& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::shared_ptr<Primitive> primitive,
|
||||||
|
std::vector<array>&& inputs);
|
||||||
};
|
};
|
||||||
|
|
||||||
// The ArrayDesc contains the details of the materialized array including the
|
// The ArrayDesc contains the details of the materialized array including the
|
||||||
// shape, strides, the data type. It also includes
|
// shape, strides, the data type. It also includes
|
||||||
// the primitive which knows how to compute the array's data from its inputs
|
// the primitive which knows how to compute the array's data from its inputs
|
||||||
// and a the list of array's inputs for the primitive.
|
// and the list of array's inputs for the primitive.
|
||||||
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
|
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -381,11 +452,23 @@ array::array(
 }
 
 template <typename T>
-T array::item(bool retain_graph /* = false */) {
+T array::item() {
   if (size() != 1) {
     throw std::invalid_argument("item can only be called on arrays of size 1.");
   }
-  eval(retain_graph);
+  eval();
+  return *data<T>();
+}
+
+template <typename T>
+T array::item() const {
+  if (size() != 1) {
+    throw std::invalid_argument("item can only be called on arrays of size 1.");
+  }
+  if (!is_evaled()) {
+    throw std::invalid_argument(
+        "item() const can only be called on evaled arrays");
+  }
   return *data<T>();
 }
 
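For context, a minimal usage sketch of the item() overloads changed above; the scalar value and constructor call are invented for illustration and are not part of the commit:

// item() on a size-1 array: the non-const overload evaluates the graph first,
// while the const overload requires the array to already be evaluated.
#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array(3.0f); // scalar array of size 1
  float v = x.item<float>();       // triggers eval(), then reads the value
  return v == 3.0f ? 0 : 1;
}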
@@ -4,6 +4,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
 )
@@ -29,12 +29,16 @@ std::tuple<bool, size_t, array> check_transpose(const array& arr) {
   }
 }
 
-inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
+inline void matmul_cblas_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
   if (out.dtype() != float32) {
     throw std::runtime_error(
         "[matmul_cblas] on CPU currently only supports float32");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -42,6 +46,14 @@ inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (M == 0 || N == 0) {
+    return;
+  }
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
   for (int i = 0; i < (a.size() / (M * K)); ++i) {
     cblas_sgemm(
         CblasRowMajor,
@@ -50,21 +62,34 @@ inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
         M,
         N,
         K,
-        1.0f, // alpha
+        alpha, // alpha
         a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
         lda,
         b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
         ldb,
-        0.0f, // beta
+        beta, // beta
         out.data<float>() + M * N * i,
         out.shape(-1) // ldc
     );
   }
 }
 
-inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
-  // TODO: Update to utilize BNNS broadcasting
+inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[matmul_cblas] on CPU currently only supports float32");
+  }
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_cblas_general(a_pre, b_pre, out);
+}
+
+inline void matmul_bnns_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
+  // TODO: Update to utilize BNNS broadcasting
 
   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -72,11 +97,19 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (M == 0 || N == 0) {
+    return;
+  }
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
   BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
 
   const BNNSLayerParametersBroadcastMatMul gemm_params{
-      /* float alpha = */ 1.0,
-      /* float beta = */ 0.0,
+      /* float alpha = */ alpha,
+      /* float beta = */ beta,
       /* bool transA = */ a_transposed,
       /* bool transB = */ b_transposed,
       /* bool quadratic = */ false,
@@ -157,6 +190,12 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
   BNNSFilterDestroy(bnns_filter);
 }
 
+inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
+  // TODO: Update to utilize BNNS broadcasting
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_bnns_general(a_pre, b_pre, out);
+}
+
 } // namespace
 
 void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -166,4 +205,16 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
     return matmul_bnns(inputs[0], inputs[1], out);
   }
 }
 
-} // namespace mlx::core
+void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  // Fill output with C
+  auto& c = inputs[2];
+  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  copy(c, out, ctype);
+
+  if (out.dtype() == float32) {
+    return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
+  }
+  return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
+}
+
+} // namespace mlx::core
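As a rough illustration of what the AddMM path above computes (out = alpha * (a @ b) + beta * c), a sketch assuming a front-end addmm(c, a, b, alpha, beta) op exists that lowers to this primitive; the shapes and constants are invented for the example and are not taken from the diff:

#include "mlx/mlx.h"
using namespace mlx::core;

int main() {
  array a = ones({2, 3});
  array b = ones({3, 4});
  array c = zeros({2, 4});
  // With alpha = 2 and beta = 0 every entry of out is 2 * 3 = 6.
  array out = addmm(c, a, b, /* alpha = */ 2.0f, /* beta = */ 0.0f);
  out.eval();
  return 0;
}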
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <cassert>
 #include <cmath>
@@ -17,6 +17,12 @@
     primitive::eval(inputs, out); \
   }
 
+#define DEFAULT_MULTI(primitive)                                       \
+  void primitive::eval_cpu(                                            \
+      const std::vector<array>& inputs, std::vector<array>& outputs) { \
+    primitive::eval(inputs, outputs);                                  \
+  }
+
 namespace mlx::core {
 
 // Use the default implementation for the following primitives
@@ -26,12 +32,18 @@ DEFAULT(ArgReduce)
 DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(Broadcast)
+DEFAULT(Ceil)
+DEFAULT_MULTI(Compiled)
 DEFAULT(Concatenate)
 DEFAULT(Copy)
+DEFAULT_MULTI(CustomVJP)
+DEFAULT_MULTI(Depends)
+DEFAULT_MULTI(DivMod)
 DEFAULT(Equal)
 DEFAULT(Erf)
 DEFAULT(ErfInv)
 DEFAULT(FFT)
+DEFAULT(Floor)
 DEFAULT(Gather)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
@@ -39,16 +51,24 @@ DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)
 DEFAULT(LogicalNot)
+DEFAULT(LogicalAnd)
+DEFAULT(LogicalOr)
 DEFAULT(LogAddExp)
+DEFAULT(Maximum)
+DEFAULT(Minimum)
 DEFAULT(NotEqual)
 DEFAULT(Pad)
 DEFAULT(Partition)
+DEFAULT_MULTI(QRF)
 DEFAULT(RandomBits)
 DEFAULT(Reshape)
+DEFAULT(Remainder)
+DEFAULT(Round)
 DEFAULT(Scatter)
 DEFAULT(Sigmoid)
 DEFAULT(Sign)
 DEFAULT(Slice)
+DEFAULT_MULTI(Split)
 DEFAULT(Sort)
 DEFAULT(StopGradient)
 DEFAULT(Transpose)
@@ -57,21 +77,11 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (in.dtype() == float32 && in.flags().contiguous) {
-    auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
-    vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
+    set_unary_output_data(in, out);
+    vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
   } else if (in.dtype() == int32 && in.flags().contiguous) {
-    auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
-    vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
+    set_unary_output_data(in, out);
+    vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
   } else if (is_unsigned(in.dtype())) {
     // No-op for unsigned types
     out.copy_shared_buffer(in);
@@ -124,12 +134,8 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvacosf(out.data<float>(), in.data<float>(), &size);
|
vvacosf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -140,12 +146,8 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvacoshf(out.data<float>(), in.data<float>(), &size);
|
vvacoshf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -156,12 +158,8 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvasinf(out.data<float>(), in.data<float>(), &size);
|
vvasinf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -172,12 +170,8 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvasinhf(out.data<float>(), in.data<float>(), &size);
|
vvasinhf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -188,12 +182,8 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvatanf(out.data<float>(), in.data<float>(), &size);
|
vvatanf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -204,12 +194,8 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvatanhf(out.data<float>(), in.data<float>(), &size);
|
vvatanhf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -221,30 +207,23 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
 
   if (in.flags().contiguous) {
-    auto allocfn = [&in, &out]() {
-      out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-    };
     // Use accelerate functions if possible
     if (in.dtype() == float32 && out.dtype() == uint32) {
-      allocfn();
+      set_unary_output_data(in, out);
       vDSP_vfixu32(
           in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
       return;
     } else if (in.dtype() == float32 && out.dtype() == int32) {
-      allocfn();
+      set_unary_output_data(in, out);
       vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
       return;
     } else if (in.dtype() == uint32 && out.dtype() == float32) {
-      allocfn();
+      set_unary_output_data(in, out);
       vDSP_vfltu32(
           in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
       return;
     } else if (in.dtype() == int32 && out.dtype() == float32) {
-      allocfn();
+      set_unary_output_data(in, out);
       vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
       return;
     }
@@ -256,12 +235,8 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvcosf(out.data<float>(), in.data<float>(), &size);
|
vvcosf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -272,12 +247,8 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvcoshf(out.data<float>(), in.data<float>(), &size);
|
vvcoshf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -326,12 +297,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
} else if (is_floating_point(out.dtype())) {
|
} else if (is_floating_point(out.dtype())) {
|
||||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||||
@@ -358,12 +325,8 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
switch (base_) {
|
switch (base_) {
|
||||||
case Base::e:
|
case Base::e:
|
||||||
vvlogf(
|
vvlogf(
|
||||||
@@ -387,12 +350,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvlog1pf(
|
vvlog1pf(
|
||||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
} else if (is_floating_point(out.dtype())) {
|
} else if (is_floating_point(out.dtype())) {
|
||||||
@@ -404,47 +363,6 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  if (out.dtype() == float32) {
-    binary(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return (x > y) ? x : y; },
-        UseDefaultBinaryOp(),
-        UseDefaultBinaryOp(),
-        [](const auto* a, const auto* b, auto* out, int n) {
-          vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
-        });
-  } else {
-    binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
-  }
-}
-
-void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (out.dtype() == float32) {
-    binary(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return (x < y) ? x : y; },
-        UseDefaultBinaryOp(),
-        UseDefaultBinaryOp(),
-        [](const auto* a, const auto* b, auto* out, int n) {
-          vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
-        });
-  } else {
-    binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
-  }
-}
-
 void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
@@ -474,13 +392,8 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (in.dtype() == float32 && in.flags().contiguous) {
-    auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
-    vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
+    set_unary_output_data(in, out);
+    vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
  } else {
     unary(in, out, [](auto x) { return -x; });
   }
@@ -493,8 +406,14 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (out.dtype() == float32 && a.flags().row_contiguous &&
       b.flags().row_contiguous) {
     int size = a.size();
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
+    if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+      out.copy_shared_buffer(a);
+    } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+      out.copy_shared_buffer(b);
+    } else {
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    }
+    vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
   } else {
     eval(inputs, out);
   }
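One note on the Power change above: per Accelerate's documented signature, vvpowf takes the exponent vector before the base vector, which is consistent with the call site now passing b (the exponent) ahead of a (the base). A standalone sketch of that calling convention, with invented values:

#include <Accelerate/Accelerate.h>
#include <cstdio>

int main() {
  float base[3] = {2.0f, 3.0f, 4.0f};
  float expo[3] = {3.0f, 2.0f, 0.5f};
  float z[3];
  int n = 3;
  vvpowf(z, expo, base, &n); // z[i] = powf(base[i], expo[i]) -> {8, 9, 2}
  std::printf("%g %g %g\n", z[0], z[1], z[2]);
  return 0;
}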
@@ -535,12 +454,8 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvsinf(out.data<float>(), in.data<float>(), &size);
|
vvsinf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -551,12 +466,8 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvsinhf(out.data<float>(), in.data<float>(), &size);
|
vvsinhf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -567,12 +478,8 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||||
} else {
|
} else {
|
||||||
unary(in, out, [](auto x) { return x * x; });
|
unary(in, out, [](auto x) { return x * x; });
|
||||||
@@ -583,12 +490,8 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
if (recip_) {
|
if (recip_) {
|
||||||
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
|
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
@@ -643,12 +546,8 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvtanf(out.data<float>(), in.data<float>(), &size);
|
vvtanf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -659,12 +558,8 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||||
|
set_unary_output_data(in, out);
|
||||||
int size = in.data_size();
|
int size = in.data_size();
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(size * out.itemsize()),
|
|
||||||
size,
|
|
||||||
in.strides(),
|
|
||||||
in.flags());
|
|
||||||
vvtanhf(out.data<float>(), in.data<float>(), &size);
|
vvtanhf(out.data<float>(), in.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
|
mlx/backend/accelerate/quantized.cpp (new file, 103 lines)
@@ -0,0 +1,103 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <simd/vector.h>
|
||||||
|
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void _qmm_t_4_64(
|
||||||
|
float* result,
|
||||||
|
const float* x,
|
||||||
|
const uint32_t* w,
|
||||||
|
const float* scales,
|
||||||
|
const float* biases,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K) {
|
||||||
|
constexpr int bits = 4;
|
||||||
|
constexpr int group_size = 64;
|
||||||
|
constexpr int bitmask = (1 << bits) - 1;
|
||||||
|
constexpr int pack_factor = 32 / bits;
|
||||||
|
constexpr int packs_in_group = group_size / pack_factor;
|
||||||
|
const int Kg = K / group_size;
|
||||||
|
const int Kw = K / pack_factor;
|
||||||
|
|
||||||
|
for (int m = 0; m < M; m++) {
|
||||||
|
const uint32_t* w_local = w;
|
||||||
|
const float* scales_local = scales;
|
||||||
|
const float* biases_local = biases;
|
||||||
|
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
const simd_float16* x_local = (simd_float16*)x;
|
||||||
|
simd_float16 sum = 0;
|
||||||
|
for (int k = 0; k < K; k += group_size) {
|
||||||
|
float scale = *scales_local++;
|
||||||
|
float bias = *biases_local++;
|
||||||
|
|
||||||
|
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
||||||
|
// TODO: vectorize this properly
|
||||||
|
simd_uint16 wi;
|
||||||
|
for (int e = 0; e < 2; e++) {
|
||||||
|
uint32_t wii = *w_local++;
|
||||||
|
for (int p = 0; p < 8; p++) {
|
||||||
|
wi[e * 8 + p] = wii & bitmask;
|
||||||
|
wii >>= bits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
simd_float16 wf = simd_float(wi);
|
||||||
|
wf *= scale;
|
||||||
|
wf += bias;
|
||||||
|
|
||||||
|
sum += (*x_local) * wf;
|
||||||
|
x_local++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*result = simd_reduce_add(sum);
|
||||||
|
result++;
|
||||||
|
}
|
||||||
|
|
||||||
|
x += K;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 4);
|
||||||
|
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& w = inputs[1];
|
||||||
|
auto& scales = inputs[2];
|
||||||
|
auto& biases = inputs[3];
|
||||||
|
|
||||||
|
bool condition =
|
||||||
|
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
|
||||||
|
scales.flags().row_contiguous && biases.flags().row_contiguous &&
|
||||||
|
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
|
||||||
|
|
||||||
|
if (condition) {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
int K = x.shape(-1);
|
||||||
|
int M = x.size() / K;
|
||||||
|
int N = out.shape(-1);
|
||||||
|
_qmm_t_4_64(
|
||||||
|
out.data<float>(),
|
||||||
|
x.data<float>(),
|
||||||
|
w.data<uint32_t>(),
|
||||||
|
scales.data<float>(),
|
||||||
|
biases.data<float>(),
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Make sure that the last dimension is contiguous
   auto check_input = [](array x) {
-    if (x.strides()[x.ndim() - 1] == 1) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      auto s = x.strides()[x.ndim() - 2];
+      no_copy &= (s == 0 || s == x.shape().back());
+    }
+    if (no_copy) {
       return x;
     } else {
       array x_copy(x.shape(), x.dtype(), nullptr, {});
@@ -3,16 +3,20 @@ target_sources(
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
 )
@@ -6,6 +6,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/binary_two.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
@@ -75,6 +76,61 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
   binary(a, b, out, [](auto x, auto y) { return x + y; });
 }
 
+void DivMod::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto integral_op = [](auto x, auto y) {
+    return std::make_pair(x / y, x % y);
+  };
+  auto float_op = [](auto x, auto y) {
+    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
+  };
+  switch (outputs[0].dtype()) {
+    case bool_:
+      binary_op<bool>(a, b, outputs, integral_op);
+    case uint8:
+      binary_op<uint8_t>(a, b, outputs, integral_op);
+      break;
+    case uint16:
+      binary_op<uint16_t>(a, b, outputs, integral_op);
+      break;
+    case uint32:
+      binary_op<uint32_t>(a, b, outputs, integral_op);
+      break;
+    case uint64:
+      binary_op<uint64_t>(a, b, outputs, integral_op);
+      break;
+    case int8:
+      binary_op<int8_t>(a, b, outputs, integral_op);
+      break;
+    case int16:
+      binary_op<int16_t>(a, b, outputs, integral_op);
+      break;
+    case int32:
+      binary_op<int32_t>(a, b, outputs, integral_op);
+      break;
+    case int64:
+      binary_op<int64_t>(a, b, outputs, integral_op);
+      break;
+    case float16:
+      binary_op<float16_t>(a, b, outputs, float_op);
+      break;
+    case float32:
+      binary_op<float>(a, b, outputs, float_op);
+      break;
+    case bfloat16:
+      binary_op<bfloat16_t>(a, b, outputs, float_op);
+      break;
+    case complex64:
+      // Should never get here
+      throw std::runtime_error("[DivMod] Complex type not supported");
+      break;
+  }
+}
+
 void Divide::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
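A small standalone sketch of the quotient/remainder pairing DivMod uses above for floating point types (std::trunc for the quotient, std::fmod for the remainder); the input values are made up for illustration:

#include <cmath>
#include <cstdio>
#include <utility>

int main() {
  auto float_op = [](float x, float y) {
    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
  };
  auto [q, r] = float_op(7.5f, 2.0f); // q = 3, r = 1.5, and q * 2 + r == 7.5
  std::printf("q=%g r=%g\n", q, r);
  return 0;
}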
@@ -82,6 +138,47 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
   binary(a, b, out, [](auto x, auto y) { return x / y; });
 }
 
+struct RemainderFn {
+  template <typename T>
+  std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
+      T numerator,
+      T denominator) {
+    return numerator % denominator;
+  }
+
+  template <typename T>
+  std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
+      T numerator,
+      T denominator) {
+    auto r = numerator % denominator;
+    if (r != 0 && (r < 0 != denominator < 0))
+      r += denominator;
+    return r;
+  }
+
+  template <typename T>
+  std::enable_if_t<!std::is_integral_v<T>, T> operator()(
+      T numerator,
+      T denominator) {
+    auto r = std::fmod(numerator, denominator);
+    if (r != 0 && (r < 0 != denominator < 0)) {
+      r += denominator;
+    }
+    return r;
+  }
+
+  complex64_t operator()(complex64_t numerator, complex64_t denominator) {
+    return numerator % denominator;
+  }
+};
+
+void Remainder::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  binary(a, b, out, RemainderFn{});
+}
+
 void Equal::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   if (equal_nan_) {
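A minimal sketch of the sign adjustment RemainderFn applies above: when the raw remainder and the denominator have opposite signs, the denominator is added back so the result follows the sign of the denominator (Python-style modulo). The values below are illustrative only:

#include <cmath>
#include <cstdio>

float adjusted_mod(float numerator, float denominator) {
  auto r = std::fmod(numerator, denominator);
  if (r != 0 && ((r < 0) != (denominator < 0))) {
    r += denominator;
  }
  return r;
}

int main() {
  // fmod(-7, 3) = -1, adjusted to 2; fmod(7, -3) = 1, adjusted to -2.
  std::printf("%g %g\n", adjusted_mod(-7.0f, 3.0f), adjusted_mod(7.0f, -3.0f));
  return 0;
}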
@@ -154,14 +251,33 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
+
+  if (is_floating_point(out.dtype())) {
+    binary(a, b, out, [](auto x, auto y) {
+      if (std::isnan(x)) {
+        return x;
+      }
+      return (x > y) ? x : y;
+    });
+  } else {
+    binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
+  }
 }
 
 void Minimum::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
+  if (is_floating_point(out.dtype())) {
+    binary(a, b, out, [](auto x, auto y) {
+      if (std::isnan(x)) {
+        return x;
+      }
+      return (x < y) ? x : y;
+    });
+  } else {
+    binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
+  }
 }
 
 void Multiply::eval(const std::vector<array>& inputs, array& out) {
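A brief sketch of the NaN handling added to Maximum/Minimum above: for floating point outputs the comparison lambda returns x whenever x is NaN, so a NaN operand propagates instead of being silently dropped by the plain (x > y) comparison. Standalone and illustrative only:

#include <cmath>
#include <cstdio>

float max_keep_nan(float x, float y) {
  if (std::isnan(x)) {
    return x;
  }
  return (x > y) ? x : y;
}

int main() {
  std::printf("%g\n", max_keep_nan(1.0f, 2.0f));          // 2
  std::printf("%g\n", max_keep_nan(std::nanf(""), 2.0f)); // nan
  return 0;
}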
@@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
@@ -40,29 +39,83 @@ void set_binary_op_output_data(
|
|||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
array& out,
|
array& out,
|
||||||
BinaryOpType bopt) {
|
BinaryOpType bopt,
|
||||||
|
bool donate_with_move = false) {
|
||||||
switch (bopt) {
|
switch (bopt) {
|
||||||
case ScalarScalar:
|
case ScalarScalar:
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||||
break;
|
break;
|
||||||
case ScalarVector:
|
case ScalarVector:
|
||||||
out.set_data(
|
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
if (donate_with_move) {
|
||||||
b.data_size(),
|
out.move_shared_buffer(b);
|
||||||
b.strides(),
|
} else {
|
||||||
b.flags());
|
out.copy_shared_buffer(b);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||||
|
b.data_size(),
|
||||||
|
b.strides(),
|
||||||
|
b.flags());
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case VectorScalar:
|
case VectorScalar:
|
||||||
|
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||||
|
if (donate_with_move) {
|
||||||
|
out.move_shared_buffer(a);
|
||||||
|
} else {
|
||||||
|
out.copy_shared_buffer(a);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||||
|
a.data_size(),
|
||||||
|
a.strides(),
|
||||||
|
a.flags());
|
||||||
|
}
|
||||||
|
break;
|
||||||
case VectorVector:
|
case VectorVector:
|
||||||
out.set_data(
|
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
if (donate_with_move) {
|
||||||
a.data_size(),
|
out.move_shared_buffer(a);
|
||||||
a.strides(),
|
} else {
|
||||||
a.flags());
|
out.copy_shared_buffer(a);
|
||||||
|
}
|
||||||
|
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||||
|
if (donate_with_move) {
|
||||||
|
out.move_shared_buffer(b);
|
||||||
|
} else {
|
||||||
|
out.copy_shared_buffer(b);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||||
|
a.data_size(),
|
||||||
|
a.strides(),
|
||||||
|
a.flags());
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case General:
|
case General:
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
if (a.is_donatable() && a.flags().row_contiguous &&
|
||||||
|
a.itemsize() == out.itemsize() && a.size() == out.size()) {
|
||||||
|
if (donate_with_move) {
|
||||||
|
out.move_shared_buffer(a);
|
||||||
|
} else {
|
||||||
|
out.copy_shared_buffer(a);
|
||||||
|
}
|
||||||
|
} else if (
|
||||||
|
b.is_donatable() && b.flags().row_contiguous &&
|
||||||
|
b.itemsize() == out.itemsize() && b.size() == out.size()) {
|
||||||
|
if (donate_with_move) {
|
||||||
|
out.move_shared_buffer(b);
|
||||||
|
} else {
|
||||||
|
out.copy_shared_buffer(b);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -73,6 +126,12 @@ struct UseDefaultBinaryOp {
|
|||||||
// Should we throw? This should normally never be called.
|
// Should we throw? This should normally never be called.
|
||||||
assert(false);
|
assert(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||||
|
// Should we throw? This should normally never be called.
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
@@ -89,6 +148,18 @@ struct DefaultVectorScalar {
|
|||||||
a++;
|
a++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||||
|
T scalar = *b;
|
||||||
|
while (size-- > 0) {
|
||||||
|
auto dst = op(*a, scalar);
|
||||||
|
*dst_a = dst.first;
|
||||||
|
*dst_b = dst.second;
|
||||||
|
dst_a++;
|
||||||
|
dst_b++;
|
||||||
|
a++;
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
@@ -105,6 +176,18 @@ struct DefaultScalarVector {
|
|||||||
b++;
|
b++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||||
|
T scalar = *a;
|
||||||
|
while (size-- > 0) {
|
||||||
|
auto dst = op(scalar, *b);
|
||||||
|
*dst_a = dst.first;
|
||||||
|
*dst_b = dst.second;
|
||||||
|
dst_a++;
|
||||||
|
dst_b++;
|
||||||
|
b++;
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
@@ -121,6 +204,18 @@ struct DefaultVectorVector {
|
|||||||
b++;
|
b++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||||
|
while (size-- > 0) {
|
||||||
|
auto dst = op(*a, *b);
|
||||||
|
*dst_a = dst.first;
|
||||||
|
*dst_b = dst.second;
|
||||||
|
dst_a++;
|
||||||
|
dst_b++;
|
||||||
|
a++;
|
||||||
|
b++;
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
|
mlx/backend/common/binary_two.h (new file, 536 lines)
@@ -0,0 +1,536 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims1(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < out_a.size(); ++i) {
|
||||||
|
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
dst_a[i] = dst.first;
|
||||||
|
dst_b[i] = dst.second;
|
||||||
|
a_idx += a.strides()[0];
|
||||||
|
b_idx += b.strides()[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims1(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op,
|
||||||
|
int stride) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; i++) {
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||||
|
a_idx += a.strides()[0];
|
||||||
|
b_idx += b.strides()[0];
|
||||||
|
dst_a += stride;
|
||||||
|
dst_b += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims2(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
dst_a[out_idx] = dst.first;
|
||||||
|
dst_b[out_idx++] = dst.second;
|
||||||
|
a_idx += a.strides()[1];
|
||||||
|
b_idx += b.strides()[1];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims2(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op,
|
||||||
|
int stride) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||||
|
a_idx += a.strides()[1];
|
||||||
|
b_idx += b.strides()[1];
|
||||||
|
dst_a += stride;
|
||||||
|
dst_b += stride;
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims3(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
        dst_a[out_idx] = dst.first;
        dst_b[out_idx++] = dst.second;
        a_idx += a.strides()[2];
        b_idx += b.strides()[2];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims4(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
          auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
          dst_a[out_idx] = dst.first;
          dst_b[out_idx++] = dst.second;
          a_idx += a.strides()[3];
          b_idx += b.strides()[3];
        }
        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op) {
  switch (out_a.ndim()) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
      return;
    case 3:
      binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
      return;
    case 4:
      binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  for (size_t i = 0; i < out_a.size(); i++) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
  }
}

template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op,
    int dim,
    int stride) {
  // Number of dimensions to loop over for vectorized ops
  switch (dim) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  for (size_t i = 0; i < out_a.size(); i += stride) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
    dst_a += stride;
    dst_b += stride;
  }
}

template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out_a, bopt);
  set_binary_op_output_data(a, b, out_b, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == ScalarScalar) {
    std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
        op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == ScalarVector) {
    opsv(
        a.data<T>(),
        b.data<T>(),
        out_a.data<U>(),
        out_b.data<U>(),
        b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == VectorScalar) {
    opvs(
        a.data<T>(),
        b.data<T>(),
        out_a.data<U>(),
        out_b.data<U>(),
        a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == VectorVector) {
    opvv(
        a.data<T>(),
        b.data<T>(),
        out_a.data<U>(),
        out_b.data<U>(),
        out_a.size());
    return;
  }

  // General computation so let's try to optimize

  // Get the left-most dim such that the array is row contiguous after
  auto& strides = out_a.strides();
  auto leftmost_rc_dim = [&strides](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a);
  auto b_rc_dim = leftmost_rc_dim(b);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a);
  auto b_s_dim = leftmost_s_dim(b);

  auto ndim = out_a.ndim();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = ScalarVector;
    dim = d;
  }

  // Can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above. Except for the case that the flags do not
  // correspond to the underlying contiguity.
  size_t stride;
  if (dim == 0 || strides[dim - 1] < 16) {
    stride = 1;
    bopt = General;
    dim = ndim;
  } else {
    stride = strides[dim - 1];
  }

  switch (bopt) {
    case VectorVector:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
      break;
    case VectorScalar:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
      break;
    case ScalarVector:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
      break;
    default:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
      break;
  }
}

template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    std::vector<array>& outputs,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  // TODO: The following mess of constexpr evaluations can probably be achieved
  // with template specializations and overloading. Would it be simpler?

  if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
    if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
      if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
        binary_op<T, T>(
            a,
            b,
            outputs[0],
            outputs[1],
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            DefaultVectorVector<T, T, Op>(op));
      } else {
        // opsv and opvs were UseDefaultBinaryOp
        binary_op<T, T>(
            a,
            b,
            outputs[0],
            outputs[1],
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            opvv);
      }
    } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opsv and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          outputs[0],
          outputs[1],
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opsv was UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          outputs[0],
          outputs[1],
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          opvv);
    }
  } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
    if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opvs and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          outputs[0],
          outputs[1],
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opvs was UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          outputs[0],
          outputs[1],
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          opvv);
    }
  } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
    // opvv was UseDefaultBinaryOp
    binary_op<T, T>(
        a,
        b,
        outputs[0],
        outputs[1],
        op,
        opsv,
        opvs,
        DefaultVectorVector<T, T, Op>(op));
  } else {
    // All ops provided
    binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
  }
}

template <typename T, typename Op>
void binary_op(
    const array& a,
    const array& b,
    std::vector<array>& outputs,
    Op op) {
  DefaultScalarVector<T, T, Op> opsv(op);
  DefaultVectorScalar<T, T, Op> opvs(op);
  DefaultVectorVector<T, T, Op> opvv(op);
  binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}

template <typename... Ops>
void binary(
    const array& a,
    const array& b,
    std::vector<array>& outputs,
    Ops... ops) {
  switch (outputs[0].dtype()) {
    case bool_:
      binary_op<bool>(a, b, outputs, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, outputs, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, outputs, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, outputs, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, outputs, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, outputs, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, outputs, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, outputs, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, outputs, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, outputs, ops...);
      break;
    case float32:
      binary_op<float>(a, b, outputs, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, outputs, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, outputs, ops...);
      break;
  }
}

} // namespace

} // namespace mlx::core
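Note: the templates above all assume that `Op` returns a pair, so a single traversal can fill two output arrays at once (this is how an op with two results, such as DivMod, can be evaluated in one pass). A minimal standalone sketch of that pair-returning pattern, independent of the mlx array machinery and using plain vectors with illustrative truncating-division semantics, might look like this:

#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// One traversal, one op, two outputs: the same shape as the
// binary_op_dims* loops above, minus strides and broadcasting.
int main() {
  std::vector<float> a{7.0f, 9.5f, -3.0f};
  std::vector<float> b{2.0f, 2.0f, 2.0f};
  std::vector<float> quot(a.size()), rem(a.size());
  auto divmod = [](float x, float y) {
    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
  };
  for (size_t i = 0; i < a.size(); ++i) {
    auto dst = divmod(a[i], b[i]);
    quot[i] = dst.first;
    rem[i] = dst.second;
  }
  for (size_t i = 0; i < a.size(); ++i) {
    std::printf("%g / %g -> %g r %g\n", a[i], b[i], quot[i], rem[i]);
  }
}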
mlx/backend/common/compiled.cpp (new file, 59 lines)
@@ -0,0 +1,59 @@
// Copyright © 2023-2024 Apple Inc.

#include <queue>

#include "mlx/primitives.h"

namespace mlx::core {

// Build the real tape
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
    const std::vector<array>& trace_tape,
    const std::vector<array>& trace_inputs,
    const std::vector<array>& trace_outputs,
    const std::vector<array>& inputs) {
  std::unordered_map<uintptr_t, array> trace_to_real;
  for (int i = 0; i < inputs.size(); ++i) {
    trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
  }
  std::queue<array> tape;
  for (auto& a : trace_tape) {
    // Find real inputs
    std::vector<array> real_inputs;
    for (auto& in : a.inputs()) {
      real_inputs.push_back(trace_to_real.at(in.id()));
    }
    tape.push(
        array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
    trace_to_real.insert({a.id(), tape.back()});
  }

  std::vector<array> outputs;
  for (auto& o : trace_outputs) {
    outputs.push_back(trace_to_real.at(o.id()));
  }
  return {tape, outputs};
}

void Compiled::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Make a real tape from the tracers
  auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);

  // Run the tape
  while (!tape.empty()) {
    auto a = std::move(tape.front());
    tape.pop();
    auto outputs = a.outputs();
    a.primitive().eval_cpu(a.inputs(), outputs);
    a.detach();
  }

  // Copy results into outputs
  for (int o = 0; o < real_outputs.size(); ++o) {
    outputs[o].copy_shared_buffer(real_outputs[o]);
  }
}

} // namespace mlx::core
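Note: trace_to_real walks the recorded tape in order and rebuilds each node against the concrete inputs by looking up every tracer id in a map. A toy version of the same substitution idea, using integers for ids and strings for values (names are illustrative, not part of the mlx API), is sketched below:

#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Tracer id -> concrete value, seeded with the real inputs.
  std::unordered_map<int, std::string> trace_to_real{{1, "x"}, {2, "y"}};
  // Each tape entry consumes earlier ids and produces a new id.
  struct Node { int id; std::vector<int> inputs; };
  std::vector<Node> tape{{3, {1, 2}}, {4, {3, 1}}};
  for (auto& n : tape) {
    std::string real = "op(";
    for (auto in : n.inputs) {
      real += trace_to_real.at(in) + ",";
    }
    real.back() = ')';
    trace_to_real.insert({n.id, real});
  }
  std::printf("%s\n", trace_to_real.at(4).c_str()); // op(op(x,y),x)
}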
@@ -3,7 +3,7 @@
 #include <cassert>
 
 #ifdef ACCELERATE_NEW_LAPACK
-#include <vecLib/cblas_new.h>
+#include <Accelerate/Accelerate.h>
 #else
 #include <cblas.h>
 #endif
@@ -357,7 +357,7 @@ void explicit_gemm_conv_1D_cpu(
     gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
   }
 
-  // Peform gemm
+  // Perform gemm
   cblas_sgemm(
       CblasRowMajor,
       CblasNoTrans, // no trans A
@@ -459,7 +459,7 @@ void explicit_gemm_conv_2D_cpu(
     gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
   }
 
-  // Peform gemm
+  // Perform gemm
   cblas_sgemm(
       CblasRowMajor,
       CblasNoTrans, // no trans A
@@ -289,11 +289,16 @@ void copy(const array& src, array& dst, CopyType ctype) {
   // Allocate the output
   switch (ctype) {
     case CopyType::Vector:
-      dst.set_data(
-          allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
-          src.data_size(),
-          src.strides(),
-          src.flags());
+      if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
+        dst.copy_shared_buffer(src);
+      } else {
+        auto size = src.data_size();
+        dst.set_data(
+            allocator::malloc_or_wait(size * dst.itemsize()),
+            size,
+            src.strides(),
+            src.flags());
+      }
       break;
     case CopyType::Scalar:
     case CopyType::General:
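Note: the new Vector branch avoids an allocation when the source buffer can be donated; conceptually the output just takes another reference to the same buffer. A toy illustration of donate-versus-copy with shared ownership (illustrative only, not the mlx allocator API):

#include <cstdio>
#include <memory>
#include <vector>

int main() {
  auto src = std::make_shared<std::vector<float>>(4, 1.0f);
  bool donatable = src.use_count() == 1;  // no one else needs this buffer
  std::shared_ptr<std::vector<float>> dst;
  if (donatable) {
    dst = src;                            // alias: zero-copy "donation"
  } else {
    dst = std::make_shared<std::vector<float>>(*src);  // real copy
  }
  std::printf("aliased: %d\n", dst.get() == src.get());
}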
@@ -1,6 +1,12 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
 #include <cblas.h>
+#endif
+
+#include <cstring>
 
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
@@ -12,6 +18,12 @@
   primitive::eval(inputs, out); \
   }
 
+#define DEFAULT_MULTI(primitive) \
+  void primitive::eval_cpu( \
+      const std::vector<array>& inputs, std::vector<array>& outputs) { \
+    primitive::eval(inputs, outputs); \
+  }
+
 namespace mlx::core {
 
 DEFAULT(Abs)
@@ -29,17 +41,24 @@ DEFAULT(ArgSort)
 DEFAULT(AsType)
 DEFAULT(AsStrided)
 DEFAULT(Broadcast)
+DEFAULT_MULTI(DivMod)
+DEFAULT(Ceil)
+DEFAULT_MULTI(Compiled)
 DEFAULT(Concatenate)
 DEFAULT(Convolution)
 DEFAULT(Copy)
 DEFAULT(Cos)
 DEFAULT(Cosh)
+DEFAULT_MULTI(CustomVJP)
+DEFAULT_MULTI(Depends)
 DEFAULT(Divide)
+DEFAULT(Remainder)
 DEFAULT(Equal)
 DEFAULT(Erf)
 DEFAULT(ErfInv)
 DEFAULT(Exp)
 DEFAULT(FFT)
+DEFAULT(Floor)
 DEFAULT(Full)
 DEFAULT(Gather)
 DEFAULT(Greater)
@@ -50,6 +69,8 @@ DEFAULT(Load)
 DEFAULT(Log)
 DEFAULT(Log1p)
 DEFAULT(LogicalNot)
+DEFAULT(LogicalAnd)
+DEFAULT(LogicalOr)
 DEFAULT(LogAddExp)
 DEFAULT(Maximum)
 DEFAULT(Minimum)
@@ -59,9 +80,12 @@ DEFAULT(NotEqual)
 DEFAULT(Pad)
 DEFAULT(Partition)
 DEFAULT(Power)
+DEFAULT_MULTI(QRF)
+DEFAULT(QuantizedMatmul)
 DEFAULT(RandomBits)
 DEFAULT(Reduce)
 DEFAULT(Reshape)
+DEFAULT(Round)
 DEFAULT(Scan)
 DEFAULT(Scatter)
 DEFAULT(Sigmoid)
@@ -71,6 +95,7 @@ DEFAULT(Sinh)
 DEFAULT(Slice)
 DEFAULT(Softmax)
 DEFAULT(Sort)
+DEFAULT_MULTI(Split)
 DEFAULT(Square)
 DEFAULT(Sqrt)
 DEFAULT(StopGradient)
@@ -79,16 +104,14 @@ DEFAULT(Tan)
 DEFAULT(Tanh)
 DEFAULT(Transpose)
 
-void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[Matmul::eval_cpu] Currently only supports float32.");
-  }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  auto& a_pre = inputs[0];
-  auto& b_pre = inputs[1];
-
+namespace {
+
+inline void matmul_common_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
   auto check_transpose = [](const array& arr) {
     auto stx = arr.strides()[arr.ndim() - 2];
     auto sty = arr.strides()[arr.ndim() - 1];
@@ -106,9 +129,17 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  int M = a.shape(-2);
-  int N = b.shape(-1);
-  int K = a.shape(-1);
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);
+  if (M == 0 || N == 0) {
+    return;
+  }
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
 
   for (int i = 0; i < (a.size() / (M * K)); ++i) {
     cblas_sgemm(
         CblasRowMajor,
@@ -117,16 +148,41 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
         M,
         N,
         K,
-        1.0f, // alpha
+        alpha, // alpha
         a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
         lda,
         b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
         ldb,
-        0.0f, // beta
+        beta, // beta
         out.data<float>() + M * N * i,
         out.shape(-1) // ldc
     );
   }
 }
+
+} // namespace
+
+void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[Matmul::eval_cpu] Currently only supports float32.");
+  }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_common_general(inputs[0], inputs[1], out);
+}
+
+void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[AddMM::eval_cpu] Currently only supports float32.");
+  }
+
+  // Fill output with C
+  auto& c = inputs[2];
+  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  copy(c, out, ctype);
+
+  return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
+}
+
 } // namespace mlx::core
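Note: with a nonzero beta the sgemm call above accumulates into whatever is already in the output buffer, which is why AddMM first copies C into out and then passes alpha_ and beta_ through. As a hedged reference for what the BLAS call computes (row-major, no transposes, naive loops purely for illustration):

#include <cstdio>
#include <vector>

// out = alpha * (A @ B) + beta * out, with A (M x K), B (K x N), out (M x N).
void naive_gemm(const std::vector<float>& A, const std::vector<float>& B,
                std::vector<float>& out, int M, int N, int K,
                float alpha, float beta) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      out[m * N + n] = alpha * acc + beta * out[m * N + n];
    }
  }
}

int main() {
  std::vector<float> A{1, 2, 3, 4};      // 2x2
  std::vector<float> B{1, 0, 0, 1};      // identity
  std::vector<float> C{10, 10, 10, 10};  // preloaded "C", as AddMM does
  naive_gemm(A, B, C, 2, 2, 2, 2.0f, 0.5f);
  for (float v : C) std::printf("%g ", v);  // 7 9 11 13
  std::printf("\n");
}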
@@ -5,7 +5,7 @@
 #include <utility>
 
 #include "mlx/allocator.h"
-#include "mlx/load.h"
+#include "mlx/io/load.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
@@ -13,7 +13,7 @@ namespace mlx::core {
 namespace {
 
 template <const uint8_t scalar_size>
-void swap_endianess(uint8_t* data_bytes, size_t N) {
+void swap_endianness(uint8_t* data_bytes, size_t N) {
   struct Elem {
     uint8_t bytes[scalar_size];
   };
@@ -39,13 +39,13 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
   if (swap_endianness_) {
     switch (out.itemsize()) {
       case 2:
-        swap_endianess<2>(out.data<uint8_t>(), out.data_size());
+        swap_endianness<2>(out.data<uint8_t>(), out.data_size());
         break;
       case 4:
-        swap_endianess<4>(out.data<uint8_t>(), out.data_size());
+        swap_endianness<4>(out.data<uint8_t>(), out.data_size());
         break;
       case 8:
-        swap_endianess<8>(out.data<uint8_t>(), out.data_size());
+        swap_endianness<8>(out.data<uint8_t>(), out.data_size());
         break;
     }
   }
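Note: the renamed swap_endianness helper reverses the bytes of each scalar_size-wide element in place. For a single 4-byte element the effect is simply this (standalone illustration, not the mlx helper itself):

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t bytes[4] = {0x01, 0x02, 0x03, 0x04};
  // Reverse the byte order of one 4-byte element in place.
  for (int i = 0; i < 2; ++i) {
    uint8_t tmp = bytes[i];
    bytes[i] = bytes[3 - i];
    bytes[3 - i] = tmp;
  }
  std::printf("%02x %02x %02x %02x\n", bytes[0], bytes[1], bytes[2], bytes[3]);
  // prints: 04 03 02 01
}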
@@ -8,6 +8,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/arange.h"
+#include "mlx/backend/common/binary.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/erf.h"
 #include "mlx/backend/common/threefry.h"
@@ -167,6 +168,17 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   out.copy_shared_buffer(in, strides, flags, in.data_size());
 }
 
+void Ceil::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (not is_integral(in.dtype())) {
+    unary_fp(in, out, [](auto x) { return std::ceil(x); });
+  } else {
+    // No-op integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
 void Concatenate::eval(const std::vector<array>& inputs, array& out) {
   std::vector<int> sizes;
   sizes.push_back(0);
@@ -220,22 +232,38 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void CustomVJP::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() > outputs.size());
+  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
+       i++, j++) {
+    outputs[i].copy_shared_buffer(inputs[j]);
+  }
+}
+
+void Depends::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() > outputs.size());
+  for (int i = 0; i < outputs.size(); i++) {
+    outputs[i].copy_shared_buffer(inputs[i]);
+  }
+}
+
 void Erf::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   switch (out.dtype()) {
     case float32:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<float>(in, out, [](auto x) { return std::erf(x); });
       break;
     case float16:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<float16_t>(in, out, [](auto x) {
         return static_cast<float16_t>(std::erf(static_cast<float>(x)));
       });
       break;
     case bfloat16:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<bfloat16_t>(in, out, [](auto x) {
         return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
       });
@@ -252,17 +280,14 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
   const auto& in = inputs[0];
   switch (out.dtype()) {
     case float32:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<float>(in, out, [](auto x) { return erfinv(x); });
       break;
     case float16:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<float16_t>(in, out, [](auto x) {
         return static_cast<float16_t>(erfinv(static_cast<float>(x)));
       });
       break;
     case bfloat16:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<bfloat16_t>(in, out, [](auto x) {
         return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
       });
@@ -287,6 +312,17 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Floor::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (not is_integral(in.dtype())) {
+    unary_fp(in, out, [](auto x) { return std::floor(x); });
+  } else {
+    // No-op integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
 void Full::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -342,6 +378,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
   unary(in, out, [](auto x) { return !x; });
 }
 
+void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
+  auto& in1 = inputs[0];
+  auto& in2 = inputs[1];
+  binary(in1, in2, out, [](auto x, auto y) { return x && y; });
+}
+
+void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2); // LogicalOr requires two input arrays
+  auto& in1 = inputs[0];
+  auto& in2 = inputs[1];
+  binary(in1, in2, out, [](auto x, auto y) { return x || y; });
+}
+
 void Negative::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -444,6 +494,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Round::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (not is_integral(in.dtype())) {
+    unary_fp(in, out, RoundOp());
+  } else {
+    // No-op integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
 void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -540,6 +601,58 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
   out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
 }
 
+void Split::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+
+  auto& in = inputs[0];
+
+  auto compute_new_flags = [](const auto& shape,
+                              const auto& strides,
+                              size_t in_data_size,
+                              auto flags) {
+    size_t data_size = 1;
+    size_t f_stride = 1;
+    size_t b_stride = 1;
+    flags.row_contiguous = true;
+    flags.col_contiguous = true;
+    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
+      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
+      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
+      f_stride *= shape[i];
+      b_stride *= shape[ri];
+      if (strides[i] > 0) {
+        data_size *= shape[i];
+      }
+    }
+
+    if (data_size == 1) {
+      // Broadcasted scalar array is contiguous.
+      flags.contiguous = true;
+    } else if (data_size == in_data_size) {
+      // Means we sliced a broadcasted dimension so leave the "no holes" flag
+      // alone.
+    } else {
+      // We sliced something. So either we are row or col contiguous or we
+      // punched a hole.
+      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
+    }
+
+    return std::pair<decltype(flags), size_t>{flags, data_size};
+  };
+
+  std::vector<int> indices(1, 0);
+  indices.insert(indices.end(), indices_.begin(), indices_.end());
+  for (int i = 0; i < indices.size(); i++) {
+    size_t offset = indices[i] * in.strides()[axis_];
+    auto [new_flags, data_size] = compute_new_flags(
+        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
+    outputs[i].copy_shared_buffer(
+        in, in.strides(), new_flags, data_size, offset);
+  }
+}
+
 void Square::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
mlx/backend/common/qrf.cpp (new file, 153 lines)
@@ -0,0 +1,153 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

template <typename T>
struct lpack;

template <>
struct lpack<float> {
  static void xgeqrf(
      const int* m,
      const int* n,
      float* a,
      const int* lda,
      float* tau,
      float* work,
      const int* lwork,
      int* info) {
    sgeqrf_(m, n, a, lda, tau, work, lwork, info);
  }
  static void xorgqr(
      const int* m,
      const int* n,
      const int* k,
      float* a,
      const int* lda,
      const float* tau,
      float* work,
      const int* lwork,
      int* info) {
    sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
  }
};

template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
  const int M = a.shape(-2);
  const int N = a.shape(-1);
  const int lda = std::max(M, N);
  size_t num_matrices = a.size() / (M * N);
  int num_reflectors = std::min(M, N);
  auto tau =
      allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);

  // Copy A to inplace input and make it col-contiguous
  array in(a.shape(), float32, nullptr, {});
  auto flags = in.flags();

  // Copy the input to be column contiguous
  flags.col_contiguous = num_matrices == 1;
  flags.row_contiguous = false;
  std::vector<size_t> strides = in.strides();
  strides[in.ndim() - 2] = 1;
  strides[in.ndim() - 1] = M;
  in.set_data(
      allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
  copy_inplace(a, in, CopyType::GeneralGeneral);

  T optimal_work;
  int lwork = -1;
  int info;

  // Compute workspace size
  lpack<T>::xgeqrf(
      &M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);

  // Update workspace size
  lwork = optimal_work;
  auto work = allocator::malloc_or_wait(sizeof(T) * lwork);

  // Loop over matrices
  for (int i = 0; i < num_matrices; ++i) {
    // Solve
    lpack<T>::xgeqrf(
        &M,
        &N,
        in.data<float>() + M * N * i,
        &lda,
        static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
        static_cast<T*>(work.raw_ptr()),
        &lwork,
        &info);
  }
  allocator::free(work);

  r.set_data(allocator::malloc_or_wait(r.nbytes()));
  copy_inplace(in, r, CopyType::General);

  for (int i = 0; i < num_matrices; ++i) {
    // Zero lower triangle
    for (int j = 0; j < r.shape(-2); ++j) {
      for (int k = 0; k < j; ++k) {
        r.data<T>()[i * N * M + j * N + k] = 0;
      }
    }
  }

  // Get work size
  lwork = -1;
  lpack<T>::xorgqr(
      &M,
      &N,
      &num_reflectors,
      nullptr,
      &lda,
      nullptr,
      &optimal_work,
      &lwork,
      &info);
  lwork = optimal_work;
  work = allocator::malloc_or_wait(sizeof(T) * lwork);

  // Loop over matrices
  for (int i = 0; i < num_matrices; ++i) {
    // Compute Q
    lpack<T>::xorgqr(
        &M,
        &N,
        &num_reflectors,
        in.data<float>() + M * N * i,
        &lda,
        static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
        static_cast<T*>(work.raw_ptr()),
        &lwork,
        &info);
  }

  q.set_data(allocator::malloc_or_wait(q.nbytes()));
  copy_inplace(in, q, CopyType::General);

  // Cleanup
  allocator::free(work);
  allocator::free(tau);
}

void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
  if (!(inputs[0].dtype() == float32)) {
    throw std::runtime_error("[QRF::eval] only supports float32.");
  }
  qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
}

} // namespace mlx::core
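Note: what the routine produces is the factorization A = Q R with Q orthonormal and R upper triangular; sgeqrf computes the Householder reflectors, sorgqr forms Q explicitly, and zeroing the lower triangle of the factored matrix yields R. A tiny self-contained check of the A = Q R identity (classical Gram-Schmidt on a 2x2 column-major matrix, purely illustrative and unrelated to the LAPACK path above):

#include <cmath>
#include <cstdio>

int main() {
  // Column-major 2x2: A = [[1, 2], [3, 4]] stored as columns {1,3} and {2,4}.
  double a0[2] = {1, 3}, a1[2] = {2, 4};
  // Gram-Schmidt: q0 = a0/|a0|, q1 = normalize(a1 - (q0.a1) q0).
  double n0 = std::sqrt(a0[0] * a0[0] + a0[1] * a0[1]);
  double q0[2] = {a0[0] / n0, a0[1] / n0};
  double r01 = q0[0] * a1[0] + q0[1] * a1[1];
  double u1[2] = {a1[0] - r01 * q0[0], a1[1] - r01 * q0[1]};
  double n1 = std::sqrt(u1[0] * u1[0] + u1[1] * u1[1]);
  double q1[2] = {u1[0] / n1, u1[1] / n1};
  // R = [[n0, r01], [0, n1]]; reconstruct A = Q R and print it (1 2 / 3 4).
  std::printf("%g %g\n%g %g\n",
              q0[0] * n0, q0[0] * r01 + q1[0] * n1,
              q0[1] * n0, q0[1] * r01 + q1[1] * n1);
}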
mlx/backend/common/quantized.cpp (new file, 285 lines)
@@ -0,0 +1,285 @@
// Copyright © 2023 Apple Inc.

#include <cassert>

#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T, int bits, int group_size>
void _qmm(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int packs_in_group = group_size / pack_factor;
  const int Ng = N / group_size;
  const int Nw = N / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const T* scales_local = scales;
    const T* biases_local = biases;

    std::fill(result, result + N, 0);

    for (int k = 0; k < K; k++) {
      T* result_local = result;
      T xi = *x++;

      for (int n = 0; n < N; n += group_size) {
        T scale = *scales_local++;
        T bias = *biases_local++;
        for (int ng = 0; ng < packs_in_group; ng++) {
          uint32_t wi = *w_local++;

#pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            (*result_local++) +=
                xi * (scale * static_cast<T>(wi & bitmask) + bias);
            wi >>= bits;
          }
        }
      }
    }

    result += N;
  }
}

template <typename T, int bits, int group_size>
void _qmm_t(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int packs_in_group = group_size / pack_factor;
  const int Kg = K / group_size;
  const int Kw = K / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const T* scales_local = scales;
    const T* biases_local = biases;

    for (int n = 0; n < N; n++) {
      const T* x_local = x;
      T sum = 0;
      for (int k = 0; k < K; k += group_size) {
        T scale = *scales_local++;
        T bias = *biases_local++;

        for (int kw = 0; kw < packs_in_group; kw++) {
          uint32_t wi = *w_local++;

#pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
            wi >>= bits;
          }
        }
      }
      *result = sum;
      result++;
    }

    x += K;
  }
}

template <typename T>
void _qmm_dispatch_typed(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K,
    int group_size,
    int bits,
    bool transposed_w) {
  switch (bits) {
    case 2: {
      switch (group_size) {
        case 32:
          if (transposed_w) {
            return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
          }
        case 64:
          if (transposed_w) {
            return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
          }
        case 128:
          if (transposed_w) {
            return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
          }
      }
    }
    case 4: {
      switch (group_size) {
        case 32:
          if (transposed_w) {
            return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
          }
        case 64:
          if (transposed_w) {
            return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
          }
        case 128:
          if (transposed_w) {
            return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
          }
      }
    }
    case 8: {
      switch (group_size) {
        case 32:
          if (transposed_w) {
            return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
          }
        case 64:
          if (transposed_w) {
            return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
          }
        case 128:
          if (transposed_w) {
            return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
          } else {
            return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
          }
      }
    }
  }
  std::ostringstream msg;
  msg << "Quantization type not supported. Provided bits=" << bits
      << " and group_size=" << group_size
      << ". The supported options are bits in "
      << "{2, 4, 8} and group_size in {64, 128}.";
  throw std::invalid_argument(msg.str());
}

void _qmm_dispatch(
    array out,
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    int bits,
    int group_size,
    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.size() / K;
  int N = out.shape(-1);

  switch (x.dtype()) {
    case float32:
      _qmm_dispatch_typed<float>(
          out.data<float>(),
          x.data<float>(),
          w.data<uint32_t>(),
          scales.data<float>(),
          biases.data<float>(),
          M,
          N,
          K,
          bits,
          group_size,
          transposed_w);
      break;
    case float16:
      _qmm_dispatch_typed<float16_t>(
          out.data<float16_t>(),
          x.data<float16_t>(),
          w.data<uint32_t>(),
          scales.data<float16_t>(),
          biases.data<float16_t>(),
          M,
          N,
          K,
          bits,
          group_size,
          transposed_w);
      break;
    case bfloat16:
      _qmm_dispatch_typed<bfloat16_t>(
          out.data<bfloat16_t>(),
          x.data<bfloat16_t>(),
          w.data<uint32_t>(),
          scales.data<bfloat16_t>(),
          biases.data<bfloat16_t>(),
          M,
          N,
          K,
          bits,
          group_size,
          transposed_w);
      break;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}

} // namespace

void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);

  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];

  auto ensure_row_contiguous = [](const array& arr) {
    if (arr.flags().row_contiguous) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      return arr_copy;
    }
  };

  auto x = ensure_row_contiguous(x_pre);
  auto w = ensure_row_contiguous(w_pre);
  auto scales = ensure_row_contiguous(scales_pre);
  auto biases = ensure_row_contiguous(biases_pre);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}

} // namespace mlx::core
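Note: each packed uint32 in the kernels above holds 32/bits quantized values, and a value is recovered as scale * q + bias, with one scale/bias pair per group_size weights. A standalone sketch of unpacking one 4-bit packed word the same way the inner loops do (values here are made up for illustration):

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int bits = 4;
  constexpr int pack_factor = 32 / bits;       // 8 values per uint32
  constexpr uint32_t bitmask = (1u << bits) - 1;
  uint32_t w = 0x76543210;                     // packed 4-bit values 0..7
  float scale = 0.5f, bias = -1.0f;            // one pair per group
  for (int p = 0; p < pack_factor; ++p) {
    float dequant = scale * static_cast<float>(w & bitmask) + bias;
    std::printf("%g ", dequant);               // -1 -0.5 0 0.5 1 1.5 2 2.5
    w >>= bits;
  }
  std::printf("\n");
}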
@@ -126,7 +126,7 @@ struct ReductionPlan {
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&
-      (x.flags().row_contiguous || x.flags().col_contiguous)) {
+      x.flags().contiguous) {
     return ContiguousAllReduce;
   }
 
mlx/backend/common/rope.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast
@@ -53,7 +53,12 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
 
   // Make sure that the last dimension is contiguous
   auto check_input = [](array x) {
-    if (x.strides().back() == 1) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      auto s = x.strides()[x.ndim() - 2];
+      no_copy &= (s == 0 || s == x.shape().back());
+    }
+    if (no_copy) {
       return x;
     } else {
       array x_copy(x.shape(), x.dtype(), nullptr, {});
@@ -53,15 +53,35 @@ struct SignOp {
   }
 };
 
+struct RoundOp {
+  template <typename T>
+  T operator()(T x) {
+    return std::rint(x);
+  }
+
+  complex64_t operator()(complex64_t x) {
+    return {std::rint(x.real()), std::rint(x.imag())};
+  }
+};
+
+void set_unary_output_data(const array& in, array& out) {
+  if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+    out.copy_shared_buffer(in);
+  } else {
+    auto size = in.data_size();
+    out.set_data(
+        allocator::malloc_or_wait(size * out.itemsize()),
+        size,
+        in.strides(),
+        in.flags());
+  }
+}
+
 template <typename T, typename Op>
 void unary_op(const array& a, array& out, Op op) {
   const T* a_ptr = a.data<T>();
   if (a.flags().contiguous) {
-    out.set_data(
-        allocator::malloc_or_wait(a.data_size() * out.itemsize()),
-        a.data_size(),
-        a.strides(),
-        a.flags());
+    set_unary_output_data(a, out);
     T* dst = out.data<T>();
     for (size_t i = 0; i < a.data_size(); ++i) {
       dst[i] = op(a_ptr[i]);